Skip to content

Commit 81df51f

Browse files
authored
[mlir][vector] Don't treat memrefs with empty stride as non-contiguous (#76848)
As per the docs [1]: ``` In absence of an explicit layout, a memref is considered to have a multi-dimensional identity affine map layout. ``` This patch makes sure that MemRefs with no strides (i.e. no explicit layout) are treated as contiguous when checking whether a particular vector is a contiguous slice of the given MemRef. [1] https://mlir.llvm.org/docs/Dialects/Builtin/#layout Follow-up for #76428.
1 parent 3d688d4 commit 81df51f

File tree

2 files changed

+49
-26
lines changed

2 files changed

+49
-26
lines changed

mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -264,26 +264,31 @@ bool vector::isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
264264
if (!succeeded(getStridesAndOffset(memrefType, stridesFull, offset)))
265265
return false;
266266
auto strides = ArrayRef<int64_t>(stridesFull).take_back(vecRank);
267+
memrefType.getLayout().isIdentity();
267268

268269
// TODO: Add support for memref with trailing dynamic shapes. Memrefs
269270
// with leading dynamic dimensions are already supported.
270271
if (ShapedType::isDynamicShape(memrefShape))
271272
return false;
272273

273-
// Cond 1: A contiguous memref will always have a unit trailing stride.
274-
if (strides.empty() || strides.back() != 1)
275-
return false;
274+
// Cond 1: Check whether `memrefType` is contiguous.
275+
if (!strides.empty()) {
276+
// Cond 1.1: A contiguous memref will always have a unit trailing stride.
277+
if (strides.back() != 1)
278+
return false;
276279

277-
// Cond 2: Strides of a contiguous memref have to match the flattened dims.
278-
strides = strides.drop_back(1);
279-
SmallVector<int64_t> flattenedDims;
280-
for (size_t i = 1; i < memrefShape.size(); i++)
281-
flattenedDims.push_back(mlir::computeProduct(memrefShape.take_back(i)));
280+
// Cond 1.2: Strides of a contiguous memref have to match the flattened
281+
// dims.
282+
strides = strides.drop_back(1);
283+
SmallVector<int64_t> flattenedDims;
284+
for (size_t i = 1; i < memrefShape.size(); i++)
285+
flattenedDims.push_back(mlir::computeProduct(memrefShape.take_back(i)));
282286

283-
if (!llvm::equal(strides, llvm::reverse(flattenedDims)))
284-
return false;
287+
if (!llvm::equal(strides, llvm::reverse(flattenedDims)))
288+
return false;
289+
}
285290

286-
// Cond 3: Compare the dims of `vectorType` against `memrefType` (in reverse).
291+
// Cond 2: Compare the dims of `vectorType` against `memrefType` (in reverse).
287292
// In the most basic case, all dims will match.
288293
auto firstNonMatchingDim =
289294
std::mismatch(vectorShape.rbegin(), vectorShape.rend(),

mlir/test/Dialect/Vector/vector-transfer-flatten.mlir

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,24 @@ func.func @transfer_read_dims_match_contiguous(
1818

1919
// -----
2020

21+
func.func @transfer_read_dims_match_contiguous_empty_stride(
22+
%arg : memref<5x4x3x2xi8>) -> vector<5x4x3x2xi8> {
23+
%c0 = arith.constant 0 : index
24+
%cst = arith.constant 0 : i8
25+
%v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
26+
memref<5x4x3x2xi8>, vector<5x4x3x2xi8>
27+
return %v : vector<5x4x3x2xi8>
28+
}
29+
30+
// CHECK-LABEL: func @transfer_read_dims_match_contiguous_empty_stride
31+
// CHECK-SAME: %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
32+
// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]
33+
// CHECK: %[[READ1D:.+]] = vector.transfer_read %[[COLLAPSED]]
34+
// CHECK: %[[VEC2D:.+]] = vector.shape_cast %[[READ1D]] : vector<120xi8> to vector<5x4x3x2xi8>
35+
// CHECK: return %[[VEC2D]]
36+
37+
// -----
38+
2139
// The shape of the memref and the vector don't match, but the vector is a
2240
// contiguous subset of the memref, so "flattenable".
2341

@@ -114,6 +132,21 @@ func.func @transfer_read_dims_mismatch_non_contiguous(
114132

115133
// -----
116134

135+
func.func @transfer_read_dims_mismatch_non_contiguous_empty_stride(
136+
%arg : memref<5x4x3x2xi8>) -> vector<2x1x2x2xi8> {
137+
%c0 = arith.constant 0 : index
138+
%cst = arith.constant 0 : i8
139+
%v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
140+
memref<5x4x3x2xi8>, vector<2x1x2x2xi8>
141+
return %v : vector<2x1x2x2xi8>
142+
}
143+
144+
// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous_empty_stride
145+
// CHECK-NOT: memref.collapse_shape
146+
// CHECK-NOT: vector.shape_cast
147+
148+
// -----
149+
117150
func.func @transfer_write_dims_match_contiguous(
118151
%arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, %vec : vector<5x4x3x2xi8>) {
119152
%c0 = arith.constant 0 : index
@@ -356,18 +389,3 @@ func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
356389
// CHECK: %[[VAL_3:.*]] = arith.muli %[[VAL_0]], %[[VAL_1]] : vector<8xi32>
357390
// CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_2]] : vector<8xi32>
358391
// CHECK: return %[[VAL_4]] : vector<8xi32>
359-
360-
// -----
361-
362-
// This test is to make sure there is no crash for empty stride.
363-
func.func @stride_empty_test(%1: memref<i16>) -> vector<32x256xi16> {
364-
%c0_i16 = arith.constant 0 : i16
365-
%3 = vector.transfer_read %1[], %c0_i16 {permutation_map = affine_map<() -> (0, 0)>} : memref<i16>, vector<32x256xi16>
366-
return %3 : vector<32x256xi16>
367-
368-
// CHECK-LABEL: func.func @stride_empty_test
369-
// CHECK: %[[VAL:.*]] = arith.constant 0 : i16
370-
// CHECK: %[[RET:.*]] = vector.transfer_read {{.*}} vector<32x256xi16>
371-
// CHECK: return %[[RET]]
372-
// CHECK-NOT: empty()
373-
}

0 commit comments

Comments
 (0)