Skip to content

[mlir] [IR] Allow zero strides in StridedLayoutAttr #116463

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 21, 2024

Conversation

cxy-1993
Copy link
Contributor

Disabling memrefs with a stride of 0 was intended to prevent internal aliasing, but this does not address all cases : internal aliasing can still occur when the stride is less than the shape.

On the other hand, a stride of 0 can be very useful in certain scenarios. For example, in architectures that support multi-dimensional DMA, we can use memref::copy with a stride of 0 to achieve a broadcast effect.

This commit removes the restriction that strides in memrefs cannot be 0.

Disabling memrefs with a stride of 0 was intended to prevent internal aliasing,
but this does not address all cases : internal aliasing can still occur when
the stride is less than the shape.

On the other hand, a stride of 0 can be very useful in certain scenarios.
For example, in architectures that support multi-dimensional DMA, we can use
memref::copy with a stride of 0 to achieve a broadcast effect.

This commit removes the restriction that strides in memrefs cannot be 0.
@llvmbot
Copy link
Member

llvmbot commented Nov 16, 2024

@llvm/pr-subscribers-mlir-memref
@llvm/pr-subscribers-mlir-affine
@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir

Author: donald chen (cxy-1993)

Changes

Disabling memrefs with a stride of 0 was intended to prevent internal aliasing, but this does not address all cases : internal aliasing can still occur when the stride is less than the shape.

On the other hand, a stride of 0 can be very useful in certain scenarios. For example, in architectures that support multi-dimensional DMA, we can use memref::copy with a stride of 0 to achieve a broadcast effect.

This commit removes the restriction that strides in memrefs cannot be 0.


Full diff: https://github.com/llvm/llvm-project/pull/116463.diff

5 Files Affected:

  • (modified) mlir/lib/IR/BuiltinAttributes.cpp (-4)
  • (modified) mlir/lib/IR/BuiltinTypes.cpp (-14)
  • (modified) mlir/test/Dialect/Affine/memref-stride-calculation.mlir (+2-2)
  • (modified) mlir/test/Dialect/MemRef/invalid.mlir (-10)
  • (modified) mlir/test/IR/invalid-builtin-types.mlir (-5)
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index 8861a940336133..f288dd42baaa16 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -245,9 +245,6 @@ AffineMap StridedLayoutAttr::getAffineMap() const {
 LogicalResult
 StridedLayoutAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                           int64_t offset, ArrayRef<int64_t> strides) {
-  if (llvm::is_contained(strides, 0))
-    return emitError() << "strides must not be zero";
-
   return success();
 }
 
@@ -1815,7 +1812,6 @@ AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides,
   for (const auto &en : llvm::enumerate(strides)) {
     auto dim = en.index();
     auto stride = en.value();
-    assert(stride != 0 && "Invalid stride specification");
     auto d = getAffineDimExpr(dim, context);
     AffineExpr mult;
     // Static case.
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 25e9f80c9963cb..c28c580690166f 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -803,20 +803,6 @@ static LogicalResult getStridesAndOffset(MemRefType t,
   for (auto &stride : strides)
     stride = simplifyAffineExpr(stride, numDims, numSymbols);
 
-  // In practice, a strided memref must be internally non-aliasing. Test
-  // against 0 as a proxy.
-  // TODO: static cases can have more advanced checks.
-  // TODO: dynamic cases would require a way to compare symbolic
-  // expressions and would probably need an affine set context propagated
-  // everywhere.
-  if (llvm::any_of(strides, [](AffineExpr e) {
-        return e == getAffineConstantExpr(0, e.getContext());
-      })) {
-    offset = AffineExpr();
-    strides.clear();
-    return failure();
-  }
-
   return success();
 }
 
diff --git a/mlir/test/Dialect/Affine/memref-stride-calculation.mlir b/mlir/test/Dialect/Affine/memref-stride-calculation.mlir
index cce1946b391e7e..29a5f5e0d5f440 100644
--- a/mlir/test/Dialect/Affine/memref-stride-calculation.mlir
+++ b/mlir/test/Dialect/Affine/memref-stride-calculation.mlir
@@ -51,9 +51,9 @@ func.func @f(%0: index) {
   %26 = memref.alloc(%0)[] : memref<?xf32, affine_map<(i)[M]->(i)>>
 // CHECK: MemRefType offset: 0 strides: 1
   %27 = memref.alloc()[%0] : memref<5xf32, affine_map<(i)[M]->(M)>>
-// CHECK: MemRefType memref<5xf32, affine_map<(d0)[s0] -> (s0)>> cannot be converted to strided form
+// CHECK: MemRefType offset: ? strides: 0
   %28 = memref.alloc()[%0] : memref<5xf32, affine_map<(i)[M]->(123)>>
-// CHECK: MemRefType memref<5xf32, affine_map<(d0)[s0] -> (123)>> cannot be converted to strided form
+// CHECK: MemRefType offset: 123 strides: 0
   %29 = memref.alloc()[%0] : memref<f32, affine_map<()[M]->(M)>>
 // CHECK: MemRefType offset: ? strides:
   %30 = memref.alloc()[%0] : memref<f32, affine_map<()[M]->(123)>>
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 51c4781c9022b2..f72ad48245f819 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -245,16 +245,6 @@ func.func @memref_reinterpret_cast_no_map_but_strides(%in: memref<?x?xf32>) {
 
 // -----
 
-func.func @memref_reinterpret_cast_non_strided_layout(%in: memref<?x?xf32>) {
-  // expected-error @+1 {{expected result type to have strided layout but found 'memref<9x10xf32, affine_map<(d0, d1) -> (d0)>>}}
-  %out = memref.reinterpret_cast %in to
-           offset: [0], sizes: [9, 10], strides: [42, 1]
-         : memref<?x?xf32> to memref<9x10xf32, affine_map<(d0, d1) -> (d0)>>
-  return
-}
-
-// -----
-
 func.func @memref_reshape_element_type_mismatch(
        %buf: memref<*xf32>, %shape: memref<1xi32>) {
   // expected-error @+1 {{element types of source and destination memref types should be the same}}
diff --git a/mlir/test/IR/invalid-builtin-types.mlir b/mlir/test/IR/invalid-builtin-types.mlir
index 07854a25000feb..51612446d2e6a6 100644
--- a/mlir/test/IR/invalid-builtin-types.mlir
+++ b/mlir/test/IR/invalid-builtin-types.mlir
@@ -99,11 +99,6 @@ func.func private @memref_incorrect_strided_ending() -> memref<?x?xf32, strided<
 
 // -----
 
-// expected-error @below {{strides must not be zero}}
-func.func private @memref_zero_stride() -> memref<?x?xf32, strided<[0, 0]>>
-
-// -----
-
 // expected-error @below {{expected the number of strides to match the rank}}
 func.func private @memref_strided_rank_mismatch() -> memref<?x?xf32, strided<[1]>>
 

@cxy-1993 cxy-1993 requested a review from joker-eph November 18, 2024 03:00
@cxy-1993
Copy link
Contributor Author

ping

@MaheshRavishankar
Copy link
Contributor

stride-0 memrefs are sketchy.... but I do see the use case you are reaching for. I am fine with dropping this restriction though in any of my use case I would flag a 0-stride memref as a bug. I am happy to stamp though.

cc @ftynse or @matthias-springer for some additional comments.

@cxy-1993
Copy link
Contributor Author

stride-0 memrefs are sketchy.... but I do see the use case you are reaching for. I am fine with dropping this restriction though in any of my use case I would flag a 0-stride memref as a bug. I am happy to stamp though.

cc @ftynse or @matthias-springer for some additional comments.

Thanks for your review. ping @ftynse @matthias-springer

Copy link
Member

@matthias-springer matthias-springer left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No objection from my side. This also makes it a bit more consistent with dynamic strides. Strides that were "dynamic 0" were not detected.

@cxy-1993 cxy-1993 merged commit dbe159b into llvm:main Nov 21, 2024
13 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants