-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir] [IR] Allow zero strides in StridedLayoutAttr #116463
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Disabling memrefs with a stride of 0 was intended to prevent internal aliasing, but this does not address all cases : internal aliasing can still occur when the stride is less than the shape. On the other hand, a stride of 0 can be very useful in certain scenarios. For example, in architectures that support multi-dimensional DMA, we can use memref::copy with a stride of 0 to achieve a broadcast effect. This commit removes the restriction that strides in memrefs cannot be 0.
@llvm/pr-subscribers-mlir-memref @llvm/pr-subscribers-mlir Author: donald chen (cxy-1993) ChangesDisabling memrefs with a stride of 0 was intended to prevent internal aliasing, but this does not address all cases : internal aliasing can still occur when the stride is less than the shape. On the other hand, a stride of 0 can be very useful in certain scenarios. For example, in architectures that support multi-dimensional DMA, we can use memref::copy with a stride of 0 to achieve a broadcast effect. This commit removes the restriction that strides in memrefs cannot be 0. Full diff: https://github.com/llvm/llvm-project/pull/116463.diff 5 Files Affected:
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index 8861a940336133..f288dd42baaa16 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -245,9 +245,6 @@ AffineMap StridedLayoutAttr::getAffineMap() const {
LogicalResult
StridedLayoutAttr::verify(function_ref<InFlightDiagnostic()> emitError,
int64_t offset, ArrayRef<int64_t> strides) {
- if (llvm::is_contained(strides, 0))
- return emitError() << "strides must not be zero";
-
return success();
}
@@ -1815,7 +1812,6 @@ AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides,
for (const auto &en : llvm::enumerate(strides)) {
auto dim = en.index();
auto stride = en.value();
- assert(stride != 0 && "Invalid stride specification");
auto d = getAffineDimExpr(dim, context);
AffineExpr mult;
// Static case.
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 25e9f80c9963cb..c28c580690166f 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -803,20 +803,6 @@ static LogicalResult getStridesAndOffset(MemRefType t,
for (auto &stride : strides)
stride = simplifyAffineExpr(stride, numDims, numSymbols);
- // In practice, a strided memref must be internally non-aliasing. Test
- // against 0 as a proxy.
- // TODO: static cases can have more advanced checks.
- // TODO: dynamic cases would require a way to compare symbolic
- // expressions and would probably need an affine set context propagated
- // everywhere.
- if (llvm::any_of(strides, [](AffineExpr e) {
- return e == getAffineConstantExpr(0, e.getContext());
- })) {
- offset = AffineExpr();
- strides.clear();
- return failure();
- }
-
return success();
}
diff --git a/mlir/test/Dialect/Affine/memref-stride-calculation.mlir b/mlir/test/Dialect/Affine/memref-stride-calculation.mlir
index cce1946b391e7e..29a5f5e0d5f440 100644
--- a/mlir/test/Dialect/Affine/memref-stride-calculation.mlir
+++ b/mlir/test/Dialect/Affine/memref-stride-calculation.mlir
@@ -51,9 +51,9 @@ func.func @f(%0: index) {
%26 = memref.alloc(%0)[] : memref<?xf32, affine_map<(i)[M]->(i)>>
// CHECK: MemRefType offset: 0 strides: 1
%27 = memref.alloc()[%0] : memref<5xf32, affine_map<(i)[M]->(M)>>
-// CHECK: MemRefType memref<5xf32, affine_map<(d0)[s0] -> (s0)>> cannot be converted to strided form
+// CHECK: MemRefType offset: ? strides: 0
%28 = memref.alloc()[%0] : memref<5xf32, affine_map<(i)[M]->(123)>>
-// CHECK: MemRefType memref<5xf32, affine_map<(d0)[s0] -> (123)>> cannot be converted to strided form
+// CHECK: MemRefType offset: 123 strides: 0
%29 = memref.alloc()[%0] : memref<f32, affine_map<()[M]->(M)>>
// CHECK: MemRefType offset: ? strides:
%30 = memref.alloc()[%0] : memref<f32, affine_map<()[M]->(123)>>
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 51c4781c9022b2..f72ad48245f819 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -245,16 +245,6 @@ func.func @memref_reinterpret_cast_no_map_but_strides(%in: memref<?x?xf32>) {
// -----
-func.func @memref_reinterpret_cast_non_strided_layout(%in: memref<?x?xf32>) {
- // expected-error @+1 {{expected result type to have strided layout but found 'memref<9x10xf32, affine_map<(d0, d1) -> (d0)>>}}
- %out = memref.reinterpret_cast %in to
- offset: [0], sizes: [9, 10], strides: [42, 1]
- : memref<?x?xf32> to memref<9x10xf32, affine_map<(d0, d1) -> (d0)>>
- return
-}
-
-// -----
-
func.func @memref_reshape_element_type_mismatch(
%buf: memref<*xf32>, %shape: memref<1xi32>) {
// expected-error @+1 {{element types of source and destination memref types should be the same}}
diff --git a/mlir/test/IR/invalid-builtin-types.mlir b/mlir/test/IR/invalid-builtin-types.mlir
index 07854a25000feb..51612446d2e6a6 100644
--- a/mlir/test/IR/invalid-builtin-types.mlir
+++ b/mlir/test/IR/invalid-builtin-types.mlir
@@ -99,11 +99,6 @@ func.func private @memref_incorrect_strided_ending() -> memref<?x?xf32, strided<
// -----
-// expected-error @below {{strides must not be zero}}
-func.func private @memref_zero_stride() -> memref<?x?xf32, strided<[0, 0]>>
-
-// -----
-
// expected-error @below {{expected the number of strides to match the rank}}
func.func private @memref_strided_rank_mismatch() -> memref<?x?xf32, strided<[1]>>
|
ping |
stride-0 memrefs are sketchy.... but I do see the use case you are reaching for. I am fine with dropping this restriction though in any of my use case I would flag a 0-stride memref as a bug. I am happy to stamp though. cc @ftynse or @matthias-springer for some additional comments. |
Thanks for your review. ping @ftynse @matthias-springer |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No objection from my side. This also makes it a bit more consistent with dynamic strides. Strides that were "dynamic 0" were not detected.
Disabling memrefs with a stride of 0 was intended to prevent internal aliasing, but this does not address all cases : internal aliasing can still occur when the stride is less than the shape.
On the other hand, a stride of 0 can be very useful in certain scenarios. For example, in architectures that support multi-dimensional DMA, we can use memref::copy with a stride of 0 to achieve a broadcast effect.
This commit removes the restriction that strides in memrefs cannot be 0.