Skip to content

Commit abf863e

Browse files
krzysz00GeorgeARM
authored andcommitted
[mlir][MemRef][~NFC] Move getStridesAndOffset() onto layouts (llvm#138011)
This commit refactors the getStridesAndOffet() method on MemRefType to just call `MemRefLayoutAttrInterface::getStridesAndOffset(shape, strides& offset&)`, allowing downstream users and future layouts (ex, a potential contiguous layout) to implement it without needing to patch BuiltinTypes or without needing them to conform their affine maps to the canonical strided form.
1 parent fd28d98 commit abf863e

File tree

6 files changed

+169
-143
lines changed

6 files changed

+169
-143
lines changed

mlir/include/mlir/IR/BuiltinAttributeInterfaces.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,12 @@ LogicalResult
270270
verifyAffineMapAsLayout(AffineMap m, ArrayRef<int64_t> shape,
271271
function_ref<InFlightDiagnostic()> emitError);
272272

273+
// Return the strides and offsets that can be inferred from the given affine
274+
// layout map given the map and a memref shape.
275+
LogicalResult getAffineMapStridesAndOffset(AffineMap map,
276+
ArrayRef<int64_t> shape,
277+
SmallVectorImpl<int64_t> &strides,
278+
int64_t &offset);
273279
} // namespace detail
274280

275281
} // namespace mlir

mlir/include/mlir/IR/BuiltinAttributeInterfaces.td

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -509,6 +509,23 @@ def MemRefLayoutAttrInterface : AttrInterface<"MemRefLayoutAttrInterface"> {
509509
return ::mlir::detail::verifyAffineMapAsLayout($_attr.getAffineMap(),
510510
shape, emitError);
511511
}]
512+
>,
513+
514+
InterfaceMethod<
515+
[{Return the strides (using ShapedType::kDynamic for the dynamic case)
516+
that this layout corresponds to into `strides` and `offset` if such exist
517+
and can be determined from a combination of the layout and the given
518+
`shape`. If these strides cannot be inferred, return failure().
519+
The values of `strides` and `offset` are undefined on failure.}],
520+
"::llvm::LogicalResult", "getStridesAndOffset",
521+
(ins "::llvm::ArrayRef<int64_t>":$shape,
522+
"::llvm::SmallVectorImpl<int64_t>&":$strides,
523+
"int64_t&":$offset),
524+
[{}],
525+
[{
526+
return ::mlir::detail::getAffineMapStridesAndOffset(
527+
$_attr.getAffineMap(), shape, strides, offset);
528+
}]
512529
>
513530
];
514531
}

mlir/include/mlir/IR/BuiltinAttributes.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1003,7 +1003,7 @@ def Builtin_SparseElementsAttr : Builtin_Attr<
10031003

10041004
def StridedLayoutAttr : Builtin_Attr<"StridedLayout", "strided_layout",
10051005
[DeclareAttrInterfaceMethods<MemRefLayoutAttrInterface,
1006-
["verifyLayout"]>]> {
1006+
["verifyLayout", "getStridesAndOffset"]>]> {
10071007
let summary = "An Attribute representing a strided layout of a shaped type";
10081008
let description = [{
10091009
Syntax:

mlir/lib/IR/BuiltinAttributeInterfaces.cpp

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,3 +83,138 @@ LogicalResult mlir::detail::verifyAffineMapAsLayout(
8383

8484
return success();
8585
}
86+
87+
// Fallback cases for terminal dim/sym/cst that are not part of a binary op (
88+
// i.e. single term). Accumulate the AffineExpr into the existing one.
89+
static void extractStridesFromTerm(AffineExpr e,
90+
AffineExpr multiplicativeFactor,
91+
MutableArrayRef<AffineExpr> strides,
92+
AffineExpr &offset) {
93+
if (auto dim = dyn_cast<AffineDimExpr>(e))
94+
strides[dim.getPosition()] =
95+
strides[dim.getPosition()] + multiplicativeFactor;
96+
else
97+
offset = offset + e * multiplicativeFactor;
98+
}
99+
100+
/// Takes a single AffineExpr `e` and populates the `strides` array with the
101+
/// strides expressions for each dim position.
102+
/// The convention is that the strides for dimensions d0, .. dn appear in
103+
/// order to make indexing intuitive into the result.
104+
static LogicalResult extractStrides(AffineExpr e,
105+
AffineExpr multiplicativeFactor,
106+
MutableArrayRef<AffineExpr> strides,
107+
AffineExpr &offset) {
108+
auto bin = dyn_cast<AffineBinaryOpExpr>(e);
109+
if (!bin) {
110+
extractStridesFromTerm(e, multiplicativeFactor, strides, offset);
111+
return success();
112+
}
113+
114+
if (bin.getKind() == AffineExprKind::CeilDiv ||
115+
bin.getKind() == AffineExprKind::FloorDiv ||
116+
bin.getKind() == AffineExprKind::Mod)
117+
return failure();
118+
119+
if (bin.getKind() == AffineExprKind::Mul) {
120+
auto dim = dyn_cast<AffineDimExpr>(bin.getLHS());
121+
if (dim) {
122+
strides[dim.getPosition()] =
123+
strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
124+
return success();
125+
}
126+
// LHS and RHS may both contain complex expressions of dims. Try one path
127+
// and if it fails try the other. This is guaranteed to succeed because
128+
// only one path may have a `dim`, otherwise this is not an AffineExpr in
129+
// the first place.
130+
if (bin.getLHS().isSymbolicOrConstant())
131+
return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(),
132+
strides, offset);
133+
return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(),
134+
strides, offset);
135+
}
136+
137+
if (bin.getKind() == AffineExprKind::Add) {
138+
auto res1 =
139+
extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset);
140+
auto res2 =
141+
extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset);
142+
return success(succeeded(res1) && succeeded(res2));
143+
}
144+
145+
llvm_unreachable("unexpected binary operation");
146+
}
147+
148+
/// A stride specification is a list of integer values that are either static
149+
/// or dynamic (encoded with ShapedType::kDynamic). Strides encode
150+
/// the distance in the number of elements between successive entries along a
151+
/// particular dimension.
152+
///
153+
/// For example, `memref<42x16xf32, (64 * d0 + d1)>` specifies a view into a
154+
/// non-contiguous memory region of `42` by `16` `f32` elements in which the
155+
/// distance between two consecutive elements along the outer dimension is `1`
156+
/// and the distance between two consecutive elements along the inner dimension
157+
/// is `64`.
158+
///
159+
/// The convention is that the strides for dimensions d0, .. dn appear in
160+
/// order to make indexing intuitive into the result.
161+
static LogicalResult getStridesAndOffset(AffineMap m, ArrayRef<int64_t> shape,
162+
SmallVectorImpl<AffineExpr> &strides,
163+
AffineExpr &offset) {
164+
if (m.getNumResults() != 1 && !m.isIdentity())
165+
return failure();
166+
167+
auto zero = getAffineConstantExpr(0, m.getContext());
168+
auto one = getAffineConstantExpr(1, m.getContext());
169+
offset = zero;
170+
strides.assign(shape.size(), zero);
171+
172+
// Canonical case for empty map.
173+
if (m.isIdentity()) {
174+
// 0-D corner case, offset is already 0.
175+
if (shape.empty())
176+
return success();
177+
auto stridedExpr = makeCanonicalStridedLayoutExpr(shape, m.getContext());
178+
if (succeeded(extractStrides(stridedExpr, one, strides, offset)))
179+
return success();
180+
assert(false && "unexpected failure: extract strides in canonical layout");
181+
}
182+
183+
// Non-canonical case requires more work.
184+
auto stridedExpr =
185+
simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
186+
if (failed(extractStrides(stridedExpr, one, strides, offset))) {
187+
offset = AffineExpr();
188+
strides.clear();
189+
return failure();
190+
}
191+
192+
// Simplify results to allow folding to constants and simple checks.
193+
unsigned numDims = m.getNumDims();
194+
unsigned numSymbols = m.getNumSymbols();
195+
offset = simplifyAffineExpr(offset, numDims, numSymbols);
196+
for (auto &stride : strides)
197+
stride = simplifyAffineExpr(stride, numDims, numSymbols);
198+
199+
return success();
200+
}
201+
202+
LogicalResult mlir::detail::getAffineMapStridesAndOffset(
203+
AffineMap map, ArrayRef<int64_t> shape, SmallVectorImpl<int64_t> &strides,
204+
int64_t &offset) {
205+
AffineExpr offsetExpr;
206+
SmallVector<AffineExpr, 4> strideExprs;
207+
if (failed(::getStridesAndOffset(map, shape, strideExprs, offsetExpr)))
208+
return failure();
209+
if (auto cst = llvm::dyn_cast<AffineConstantExpr>(offsetExpr))
210+
offset = cst.getValue();
211+
else
212+
offset = ShapedType::kDynamic;
213+
for (auto e : strideExprs) {
214+
if (auto c = llvm::dyn_cast<AffineConstantExpr>(e))
215+
strides.push_back(c.getValue());
216+
else
217+
strides.push_back(ShapedType::kDynamic);
218+
}
219+
return success();
220+
}

mlir/lib/IR/BuiltinAttributes.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,15 @@ LogicalResult StridedLayoutAttr::verifyLayout(
258258
return success();
259259
}
260260

261+
LogicalResult
262+
StridedLayoutAttr::getStridesAndOffset(ArrayRef<int64_t>,
263+
SmallVectorImpl<int64_t> &strides,
264+
int64_t &offset) const {
265+
llvm::append_range(strides, getStrides());
266+
offset = getOffset();
267+
return success();
268+
}
269+
261270
//===----------------------------------------------------------------------===//
262271
// StringAttr
263272
//===----------------------------------------------------------------------===//

mlir/lib/IR/BuiltinTypes.cpp

Lines changed: 1 addition & 142 deletions
Original file line numberDiff line numberDiff line change
@@ -715,150 +715,9 @@ MemRefType MemRefType::canonicalizeStridedLayout() {
715715
return MemRefType::Builder(*this).setLayout({});
716716
}
717717

718-
// Fallback cases for terminal dim/sym/cst that are not part of a binary op (
719-
// i.e. single term). Accumulate the AffineExpr into the existing one.
720-
static void extractStridesFromTerm(AffineExpr e,
721-
AffineExpr multiplicativeFactor,
722-
MutableArrayRef<AffineExpr> strides,
723-
AffineExpr &offset) {
724-
if (auto dim = dyn_cast<AffineDimExpr>(e))
725-
strides[dim.getPosition()] =
726-
strides[dim.getPosition()] + multiplicativeFactor;
727-
else
728-
offset = offset + e * multiplicativeFactor;
729-
}
730-
731-
/// Takes a single AffineExpr `e` and populates the `strides` array with the
732-
/// strides expressions for each dim position.
733-
/// The convention is that the strides for dimensions d0, .. dn appear in
734-
/// order to make indexing intuitive into the result.
735-
static LogicalResult extractStrides(AffineExpr e,
736-
AffineExpr multiplicativeFactor,
737-
MutableArrayRef<AffineExpr> strides,
738-
AffineExpr &offset) {
739-
auto bin = dyn_cast<AffineBinaryOpExpr>(e);
740-
if (!bin) {
741-
extractStridesFromTerm(e, multiplicativeFactor, strides, offset);
742-
return success();
743-
}
744-
745-
if (bin.getKind() == AffineExprKind::CeilDiv ||
746-
bin.getKind() == AffineExprKind::FloorDiv ||
747-
bin.getKind() == AffineExprKind::Mod)
748-
return failure();
749-
750-
if (bin.getKind() == AffineExprKind::Mul) {
751-
auto dim = dyn_cast<AffineDimExpr>(bin.getLHS());
752-
if (dim) {
753-
strides[dim.getPosition()] =
754-
strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
755-
return success();
756-
}
757-
// LHS and RHS may both contain complex expressions of dims. Try one path
758-
// and if it fails try the other. This is guaranteed to succeed because
759-
// only one path may have a `dim`, otherwise this is not an AffineExpr in
760-
// the first place.
761-
if (bin.getLHS().isSymbolicOrConstant())
762-
return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(),
763-
strides, offset);
764-
return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(),
765-
strides, offset);
766-
}
767-
768-
if (bin.getKind() == AffineExprKind::Add) {
769-
auto res1 =
770-
extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset);
771-
auto res2 =
772-
extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset);
773-
return success(succeeded(res1) && succeeded(res2));
774-
}
775-
776-
llvm_unreachable("unexpected binary operation");
777-
}
778-
779-
/// A stride specification is a list of integer values that are either static
780-
/// or dynamic (encoded with ShapedType::kDynamic). Strides encode
781-
/// the distance in the number of elements between successive entries along a
782-
/// particular dimension.
783-
///
784-
/// For example, `memref<42x16xf32, (64 * d0 + d1)>` specifies a view into a
785-
/// non-contiguous memory region of `42` by `16` `f32` elements in which the
786-
/// distance between two consecutive elements along the outer dimension is `1`
787-
/// and the distance between two consecutive elements along the inner dimension
788-
/// is `64`.
789-
///
790-
/// The convention is that the strides for dimensions d0, .. dn appear in
791-
/// order to make indexing intuitive into the result.
792-
static LogicalResult getStridesAndOffset(MemRefType t,
793-
SmallVectorImpl<AffineExpr> &strides,
794-
AffineExpr &offset) {
795-
AffineMap m = t.getLayout().getAffineMap();
796-
797-
if (m.getNumResults() != 1 && !m.isIdentity())
798-
return failure();
799-
800-
auto zero = getAffineConstantExpr(0, t.getContext());
801-
auto one = getAffineConstantExpr(1, t.getContext());
802-
offset = zero;
803-
strides.assign(t.getRank(), zero);
804-
805-
// Canonical case for empty map.
806-
if (m.isIdentity()) {
807-
// 0-D corner case, offset is already 0.
808-
if (t.getRank() == 0)
809-
return success();
810-
auto stridedExpr =
811-
makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
812-
if (succeeded(extractStrides(stridedExpr, one, strides, offset)))
813-
return success();
814-
assert(false && "unexpected failure: extract strides in canonical layout");
815-
}
816-
817-
// Non-canonical case requires more work.
818-
auto stridedExpr =
819-
simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
820-
if (failed(extractStrides(stridedExpr, one, strides, offset))) {
821-
offset = AffineExpr();
822-
strides.clear();
823-
return failure();
824-
}
825-
826-
// Simplify results to allow folding to constants and simple checks.
827-
unsigned numDims = m.getNumDims();
828-
unsigned numSymbols = m.getNumSymbols();
829-
offset = simplifyAffineExpr(offset, numDims, numSymbols);
830-
for (auto &stride : strides)
831-
stride = simplifyAffineExpr(stride, numDims, numSymbols);
832-
833-
return success();
834-
}
835-
836718
LogicalResult MemRefType::getStridesAndOffset(SmallVectorImpl<int64_t> &strides,
837719
int64_t &offset) {
838-
// Happy path: the type uses the strided layout directly.
839-
if (auto strided = llvm::dyn_cast<StridedLayoutAttr>(getLayout())) {
840-
llvm::append_range(strides, strided.getStrides());
841-
offset = strided.getOffset();
842-
return success();
843-
}
844-
845-
// Otherwise, defer to the affine fallback as layouts are supposed to be
846-
// convertible to affine maps.
847-
AffineExpr offsetExpr;
848-
SmallVector<AffineExpr, 4> strideExprs;
849-
if (failed(::getStridesAndOffset(*this, strideExprs, offsetExpr)))
850-
return failure();
851-
if (auto cst = llvm::dyn_cast<AffineConstantExpr>(offsetExpr))
852-
offset = cst.getValue();
853-
else
854-
offset = ShapedType::kDynamic;
855-
for (auto e : strideExprs) {
856-
if (auto c = llvm::dyn_cast<AffineConstantExpr>(e))
857-
strides.push_back(c.getValue());
858-
else
859-
strides.push_back(ShapedType::kDynamic);
860-
}
861-
return success();
720+
return getLayout().getStridesAndOffset(getShape(), strides, offset);
862721
}
863722

864723
std::pair<SmallVector<int64_t>, int64_t> MemRefType::getStridesAndOffset() {

0 commit comments

Comments
 (0)