Skip to content

[mlir][vector] Don't treat memrefs with empty stride as non-contiguous #76848

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 9, 2024

Conversation

banach-space
Copy link
Contributor

As per the docs [1]:

In absence of an explicit layout, a memref is considered to have a
multi-dimensional identity affine map layout.

This patch makes sure that MemRefs with no strides (i.e. no explicit
layout) are treated as contiguous when checking whether a particular
vector is a contiguous slice of the given MemRef.

[1] https://mlir.llvm.org/docs/Dialects/Builtin/#layout

Follow-up for #76428.

@llvmbot
Copy link
Member

llvmbot commented Jan 3, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: Andrzej Warzyński (banach-space)

Changes

As per the docs [1]:

In absence of an explicit layout, a memref is considered to have a
multi-dimensional identity affine map layout.

This patch makes sure that MemRefs with no strides (i.e. no explicit
layout) are treated as contiguous when checking whether a particular
vector is a contiguous slice of the given MemRef.

[1] https://mlir.llvm.org/docs/Dialects/Builtin/#layout

Follow-up for #76428.


Full diff: https://github.com/llvm/llvm-project/pull/76848.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp (+15-11)
  • (modified) mlir/test/Dialect/Vector/vector-transfer-flatten.mlir (+33-15)
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index c1c0f5483a6af5..e9eb65aef6a22e 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -270,20 +270,24 @@ bool vector::isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
   if (ShapedType::isDynamicShape(memrefShape))
     return false;
 
-  // Cond 1: A contiguous memref will always have a unit trailing stride.
-  if (strides.empty() || strides.back() != 1)
-    return false;
+  // Cond 1: Check whether `memrefType` is contiguous.
+  if (!strides.empty()) {
+    // Cond 1.1: A contiguous memref will always have a unit trailing stride.
+    if (strides.back() != 1)
+      return false;
 
-  // Cond 2: Strides of a contiguous memref have to match the flattened dims.
-  strides = strides.drop_back(1);
-  SmallVector<int64_t> flattenedDims;
-  for (size_t i = 1; i < memrefShape.size(); i++)
-    flattenedDims.push_back(mlir::computeProduct(memrefShape.take_back(i)));
+    // Cond 1.2: Strides of a contiguous memref have to match the flattened
+    // dims.
+    strides = strides.drop_back(1);
+    SmallVector<int64_t> flattenedDims;
+    for (size_t i = 1; i < memrefShape.size(); i++)
+      flattenedDims.push_back(mlir::computeProduct(memrefShape.take_back(i)));
 
-  if (!llvm::equal(strides, llvm::reverse(flattenedDims)))
-    return false;
+    if (!llvm::equal(strides, llvm::reverse(flattenedDims)))
+      return false;
+  }
 
-  // Cond 3: Compare the dims of `vectorType` against `memrefType` (in reverse).
+  // Cond 2: Compare the dims of `vectorType` against `memrefType` (in reverse).
   // In the most basic case, all dims will match.
   auto firstNonMatchingDim =
       std::mismatch(vectorShape.rbegin(), vectorShape.rend(),
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index ae457ea81ec5b1..79e2b97148f3f4 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -18,6 +18,24 @@ func.func @transfer_read_dims_match_contiguous(
 
 // -----
 
+func.func @transfer_read_dims_match_contiguous_empty_stride(
+    %arg : memref<5x4x3x2xi8>) -> vector<5x4x3x2xi8> {
+    %c0 = arith.constant 0 : index
+    %cst = arith.constant 0 : i8
+    %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+      memref<5x4x3x2xi8>, vector<5x4x3x2xi8>
+    return %v : vector<5x4x3x2xi8>
+}
+
+// CHECK-LABEL: func @tansfer_read_dims_match_contiguous_empty_stride
+// CHECK-SAME:    %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
+// CHECK:         %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]
+// CHECK:         %[[READ1D:.+]] = vector.transfer_read %[[COLLAPSED]]
+// CHECK:         %[[VEC2D:.+]] = vector.shape_cast %[[READ1D]] : vector<120xi8> to vector<5x4x3x2xi8>
+// CHECK:         return %[[VEC2D]]
+
+// -----
+
 // The shape of the memref and the vector don't match, but the vector is a
 // contiguous subset of the memref, so "flattenable".
 
@@ -114,6 +132,21 @@ func.func @transfer_read_dims_mismatch_non_contiguous(
 
 // -----
 
+func.func @transfer_read_dims_mismatch_non_contiguous_empty_stride(
+    %arg : memref<5x4x3x2xi8>) -> vector<2x1x2x2xi8> {
+    %c0 = arith.constant 0 : index
+    %cst = arith.constant 0 : i8
+    %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+      memref<5x4x3x2xi8>, vector<2x1x2x2xi8>
+    return %v : vector<2x1x2x2xi8>
+}
+
+// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous_empty_stride
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
+
+// -----
+
 func.func @transfer_write_dims_match_contiguous(
       %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, %vec : vector<5x4x3x2xi8>) {
     %c0 = arith.constant 0 : index
@@ -356,18 +389,3 @@ func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
 // CHECK:           %[[VAL_3:.*]] = arith.muli %[[VAL_0]], %[[VAL_1]] : vector<8xi32>
 // CHECK:           %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_2]] : vector<8xi32>
 // CHECK:           return %[[VAL_4]] : vector<8xi32>
-
-// -----
-
-// This test is to make sure there is no crash for empty stride.
-func.func @stride_empty_test(%1: memref<i16>) -> vector<32x256xi16> {
-  %c0_i16 = arith.constant 0 : i16
-  %3 = vector.transfer_read %1[], %c0_i16 {permutation_map = affine_map<() -> (0, 0)>} : memref<i16>, vector<32x256xi16>
-  return %3 : vector<32x256xi16>
-
-  // CHECK-LABEL: func.func @stride_empty_test
-  // CHECK: %[[VAL:.*]] = arith.constant 0 : i16
-  // CHECK: %[[RET:.*]] = vector.transfer_read {{.*}} vector<32x256xi16>
-  // CHECK: return %[[RET]]
-  // CHECK-NOT: empty()
-}

@dcaballe
Copy link
Contributor

dcaballe commented Jan 8, 2024

Thanks! LGTM but wondering why memrefs don't actually have multi-dimensional identity affine map by default (that is not printed)?

As per the docs [1]:

```
In absence of an explicit layout, a memref is considered to have a
multi-dimensional identity affine map layout.
```

This patch makes sure that MemRefs with no strides (i.e. no explicit
layout) are treated as contiguous when checking whether a particular
vector is a contiguous slice of the given MemRef.

[1] https://mlir.llvm.org/docs/Dialects/Builtin/#layout

Follow-up for llvm#76428.
@banach-space banach-space force-pushed the andrzej/update_is_contiguous branch from 16dc059 to f9b676a Compare January 8, 2024 21:46
@banach-space
Copy link
Contributor Author

Thanks! LGTM but wondering why memrefs don't actually have multi-dimensional identity affine map by default (that is not printed)?

From the docs:

The layout is an attribute that implements MemRefLayoutAttrInterface.

So yes, I could use this instead:

 memrefType.getLayout().isIdentity();

But then I'd still need to access strides if the layout is not an identity map.

@banach-space banach-space merged commit 81df51f into llvm:main Jan 9, 2024
@nicolasvasilache
Copy link
Contributor

Let's please revert this and not reinvent the wheel in every client.
Please use and/or improve bool isLastMemrefDimUnitStride(MemRefType type); from BuiltinTypes.h.

@banach-space
Copy link
Contributor Author

Let's please revert this and not reinvent the wheel in every client. Please use and/or improve bool isLastMemrefDimUnitStride(MemRefType type); from BuiltinTypes.h.

There's a few patches involved and I'm guessing that you are referring to the overall logic rather than this change specifically?

I could re-use isLastMemrefDimUnitStride, but I think that we need a dedicated hook. Would this work:

?

banach-space added a commit to banach-space/llvm-project that referenced this pull request Jan 16, 2024
Extracts logic to check whether the trailing dim of a memref are
contiguous into a dedicated hook in BuiitinTypes.{h|cpp}.

Follow-up for llvm#76848.
@banach-space banach-space deleted the andrzej/update_is_contiguous branch January 16, 2024 18:54
justinfargnoli pushed a commit to justinfargnoli/llvm-project that referenced this pull request Jan 28, 2024
llvm#76848)

As per the docs [1]:

```
In absence of an explicit layout, a memref is considered to have a
multi-dimensional identity affine map layout.
```

This patch makes sure that MemRefs with no strides (i.e. no explicit
layout) are treated as contiguous when checking whether a particular
vector is a contiguous slice of the given MemRef.

[1] https://mlir.llvm.org/docs/Dialects/Builtin/#layout

Follow-up for llvm#76428.
banach-space added a commit that referenced this pull request Feb 17, 2024
Extracts logic from `vector::isContiguousSlice` to check whether
the trailing dim of a memref are contiguous into a dedicated hook
in BuiitinTypes.{h|cpp}.

Follow-up for #76848.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants