-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][vector] Don't treat memrefs with empty stride as non-contiguous #76848
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][vector] Don't treat memrefs with empty stride as non-contiguous #76848
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-vector Author: Andrzej Warzyński (banach-space) ChangesAs per the docs [1]:
This patch makes sure that MemRefs with no strides (i.e. no explicit [1] https://mlir.llvm.org/docs/Dialects/Builtin/#layout Follow-up for #76428. Full diff: https://github.com/llvm/llvm-project/pull/76848.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index c1c0f5483a6af5..e9eb65aef6a22e 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -270,20 +270,24 @@ bool vector::isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
if (ShapedType::isDynamicShape(memrefShape))
return false;
- // Cond 1: A contiguous memref will always have a unit trailing stride.
- if (strides.empty() || strides.back() != 1)
- return false;
+ // Cond 1: Check whether `memrefType` is contiguous.
+ if (!strides.empty()) {
+ // Cond 1.1: A contiguous memref will always have a unit trailing stride.
+ if (strides.back() != 1)
+ return false;
- // Cond 2: Strides of a contiguous memref have to match the flattened dims.
- strides = strides.drop_back(1);
- SmallVector<int64_t> flattenedDims;
- for (size_t i = 1; i < memrefShape.size(); i++)
- flattenedDims.push_back(mlir::computeProduct(memrefShape.take_back(i)));
+ // Cond 1.2: Strides of a contiguous memref have to match the flattened
+ // dims.
+ strides = strides.drop_back(1);
+ SmallVector<int64_t> flattenedDims;
+ for (size_t i = 1; i < memrefShape.size(); i++)
+ flattenedDims.push_back(mlir::computeProduct(memrefShape.take_back(i)));
- if (!llvm::equal(strides, llvm::reverse(flattenedDims)))
- return false;
+ if (!llvm::equal(strides, llvm::reverse(flattenedDims)))
+ return false;
+ }
- // Cond 3: Compare the dims of `vectorType` against `memrefType` (in reverse).
+ // Cond 2: Compare the dims of `vectorType` against `memrefType` (in reverse).
// In the most basic case, all dims will match.
auto firstNonMatchingDim =
std::mismatch(vectorShape.rbegin(), vectorShape.rend(),
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index ae457ea81ec5b1..79e2b97148f3f4 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -18,6 +18,24 @@ func.func @transfer_read_dims_match_contiguous(
// -----
+func.func @transfer_read_dims_match_contiguous_empty_stride(
+ %arg : memref<5x4x3x2xi8>) -> vector<5x4x3x2xi8> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0 : i8
+ %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+ memref<5x4x3x2xi8>, vector<5x4x3x2xi8>
+ return %v : vector<5x4x3x2xi8>
+}
+
+// CHECK-LABEL: func @tansfer_read_dims_match_contiguous_empty_stride
+// CHECK-SAME: %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
+// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]
+// CHECK: %[[READ1D:.+]] = vector.transfer_read %[[COLLAPSED]]
+// CHECK: %[[VEC2D:.+]] = vector.shape_cast %[[READ1D]] : vector<120xi8> to vector<5x4x3x2xi8>
+// CHECK: return %[[VEC2D]]
+
+// -----
+
// The shape of the memref and the vector don't match, but the vector is a
// contiguous subset of the memref, so "flattenable".
@@ -114,6 +132,21 @@ func.func @transfer_read_dims_mismatch_non_contiguous(
// -----
+func.func @transfer_read_dims_mismatch_non_contiguous_empty_stride(
+ %arg : memref<5x4x3x2xi8>) -> vector<2x1x2x2xi8> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0 : i8
+ %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+ memref<5x4x3x2xi8>, vector<2x1x2x2xi8>
+ return %v : vector<2x1x2x2xi8>
+}
+
+// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous_empty_stride
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
+
+// -----
+
func.func @transfer_write_dims_match_contiguous(
%arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, %vec : vector<5x4x3x2xi8>) {
%c0 = arith.constant 0 : index
@@ -356,18 +389,3 @@ func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
// CHECK: %[[VAL_3:.*]] = arith.muli %[[VAL_0]], %[[VAL_1]] : vector<8xi32>
// CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_2]] : vector<8xi32>
// CHECK: return %[[VAL_4]] : vector<8xi32>
-
-// -----
-
-// This test is to make sure there is no crash for empty stride.
-func.func @stride_empty_test(%1: memref<i16>) -> vector<32x256xi16> {
- %c0_i16 = arith.constant 0 : i16
- %3 = vector.transfer_read %1[], %c0_i16 {permutation_map = affine_map<() -> (0, 0)>} : memref<i16>, vector<32x256xi16>
- return %3 : vector<32x256xi16>
-
- // CHECK-LABEL: func.func @stride_empty_test
- // CHECK: %[[VAL:.*]] = arith.constant 0 : i16
- // CHECK: %[[RET:.*]] = vector.transfer_read {{.*}} vector<32x256xi16>
- // CHECK: return %[[RET]]
- // CHECK-NOT: empty()
-}
|
Thanks! LGTM but wondering why memrefs don't actually have |
As per the docs [1]: ``` In absence of an explicit layout, a memref is considered to have a multi-dimensional identity affine map layout. ``` This patch makes sure that MemRefs with no strides (i.e. no explicit layout) are treated as contiguous when checking whether a particular vector is a contiguous slice of the given MemRef. [1] https://mlir.llvm.org/docs/Dialects/Builtin/#layout Follow-up for llvm#76428.
16dc059
to
f9b676a
Compare
From the docs:
So yes, I could use this instead: memrefType.getLayout().isIdentity(); But then I'd still need to access strides if the layout is not an identity map. |
Let's please revert this and not reinvent the wheel in every client. |
There's a few patches involved and I'm guessing that you are referring to the overall logic rather than this change specifically? I could re-use ? |
Extracts logic to check whether the trailing dim of a memref are contiguous into a dedicated hook in BuiitinTypes.{h|cpp}. Follow-up for llvm#76848.
llvm#76848) As per the docs [1]: ``` In absence of an explicit layout, a memref is considered to have a multi-dimensional identity affine map layout. ``` This patch makes sure that MemRefs with no strides (i.e. no explicit layout) are treated as contiguous when checking whether a particular vector is a contiguous slice of the given MemRef. [1] https://mlir.llvm.org/docs/Dialects/Builtin/#layout Follow-up for llvm#76428.
Extracts logic from `vector::isContiguousSlice` to check whether the trailing dim of a memref are contiguous into a dedicated hook in BuiitinTypes.{h|cpp}. Follow-up for #76848.
As per the docs [1]:
This patch makes sure that MemRefs with no strides (i.e. no explicit
layout) are treated as contiguous when checking whether a particular
vector is a contiguous slice of the given MemRef.
[1] https://mlir.llvm.org/docs/Dialects/Builtin/#layout
Follow-up for #76428.