Skip to content

Commit 8171eac

Browse files
authored
[mlir][Vector] Update patterns for flattening vector.xfer Ops (1/N) (#73522)
Updates "flatten vector" patterns to support more cases, namely Ops that read/write vectors with leading unit dims. For example: ```mlir %0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0] ... : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<1x1x2x2xi8> ``` Currently, the `vector.transfer_read` above would not be flattened. With this change, it will be rewritten as follows: ```mlir %collapse_shape = memref.collapse_shape %arg0 [[0, 1, 2, 3]] : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<120xi8, strided<[1], offset: ?>> %0 = vector.transfer_read %collapse_shape[%c0] ... : memref<120xi8, strided<[1], offset: ?>>, vector<4xi8> %1 = vector.shape_cast %0 : vector<4xi8> to vector<1x1x2x2xi8> ``` `hasMatchingInnerContigousShape` is generalised and renamed as `isContiguousSlice` to better match the updated functionality. A few test names are updated to better highlight what case is being exercised.
1 parent f827b95 commit 8171eac

File tree

4 files changed

+161
-42
lines changed

4 files changed

+161
-42
lines changed

mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,39 @@ Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim);
4242
/// on a 2D slice. Otherwise, returns a failure.
4343
FailureOr<std::pair<int, int>> isTranspose2DSlice(vector::TransposeOp op);
4444

45+
/// Return true if `vectorType` is a contiguous slice of `memrefType`.
46+
///
47+
/// Only the N = vectorType.getRank() trailing dims of `memrefType` are
48+
/// checked (the other dims are not relevant). Note that for `vectorType` to be
49+
/// a contiguous slice of `memrefType`, the trailing dims of the latter have
50+
/// to be contiguous - this is checked by looking at the corresponding strides.
51+
///
52+
/// There might be some restrictions on the leading dims of `vectorType`:
53+
///
54+
/// Case 1. If all the trailing dims of `vectorType` match the trailing dims
55+
/// of `memrefType` then the leading dim of `vectorType` can be
56+
/// arbitrary.
57+
///
58+
/// Ex. 1.1 contiguous slice, perfect match
59+
/// vector<4x3x2xi32> from memref<5x4x3x2xi32>
60+
/// Ex. 1.2 contiguous slice, the leading dim does not match (2 != 4)
61+
/// vector<2x3x2xi32> from memref<5x4x3x2xi32>
62+
///
63+
/// Case 2. If an "internal" dim of `vectorType` does not match the
64+
/// corresponding trailing dim in `memrefType` then the remaining
65+
/// leading dims of `vectorType` have to be 1 (the first non-matching
66+
/// dim can be arbitrary).
67+
///
68+
/// Ex. 2.1 non-contiguous slice, 2 != 3 and the leading dim != <1>
69+
/// vector<2x2x2xi32> from memref<5x4x3x2xi32>
70+
/// Ex. 2.2 contiguous slice, 2 != 3 and the leading dim == <1>
71+
/// vector<1x2x2xi32> from memref<5x4x3x2xi32>
72+
/// Ex. 2.3. contiguous slice, 2 != 3 and the leading dims == <1x1>
73+
/// vector<1x1x2x2xi32> from memref<5x4x3x2xi32>
74+
/// Ex. 2.4. non-contiguous slice, 2 != 3 and the leading dims != <1x1>
75+
/// vector<2x1x2x2xi32> from memref<5x4x3x2xi32>
76+
bool isContiguousSlice(MemRefType memrefType, VectorType vectorType);
77+
4578
} // namespace vector
4679

4780
/// Constructs a permutation map of invariant memref indices to vector

mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp

Lines changed: 2 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -491,28 +491,6 @@ class TransferWriteDropUnitDimsPattern
491491

492492
} // namespace
493493

494-
/// Return true if the memref type has its inner dimension matching the given
495-
/// shape. Otherwise return false.
496-
static int64_t hasMatchingInnerContigousShape(MemRefType memrefType,
497-
ArrayRef<int64_t> targetShape) {
498-
auto shape = memrefType.getShape();
499-
SmallVector<int64_t> strides;
500-
int64_t offset;
501-
if (!succeeded(getStridesAndOffset(memrefType, strides, offset)))
502-
return false;
503-
if (strides.back() != 1)
504-
return false;
505-
strides.pop_back();
506-
int64_t flatDim = 1;
507-
for (auto [targetDim, memrefDim, memrefStride] :
508-
llvm::reverse(llvm::zip(targetShape, shape, strides))) {
509-
flatDim *= memrefDim;
510-
if (flatDim != memrefStride || targetDim != memrefDim)
511-
return false;
512-
}
513-
return true;
514-
}
515-
516494
/// Creates a memref.collapse_shape collapsing all inner dimensions of the
517495
/// input starting at `firstDimToCollapse`.
518496
static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
@@ -572,9 +550,7 @@ class FlattenContiguousRowMajorTransferReadPattern
572550
if (vectorType.getRank() <= 1)
573551
// Already 0D/1D, nothing to do.
574552
return failure();
575-
if (!hasMatchingInnerContigousShape(
576-
sourceType,
577-
vectorType.getShape().take_back(vectorType.getRank() - 1)))
553+
if (!vector::isContiguousSlice(sourceType, vectorType))
578554
return failure();
579555
int64_t firstContiguousInnerDim =
580556
sourceType.getRank() - vectorType.getRank();
@@ -632,9 +608,7 @@ class FlattenContiguousRowMajorTransferWritePattern
632608
if (vectorType.getRank() <= 1)
633609
// Already 0D/1D, nothing to do.
634610
return failure();
635-
if (!hasMatchingInnerContigousShape(
636-
sourceType,
637-
vectorType.getShape().take_back(vectorType.getRank() - 1)))
611+
if (!vector::isContiguousSlice(sourceType, vectorType))
638612
return failure();
639613
int64_t firstContiguousInnerDim =
640614
sourceType.getRank() - vectorType.getRank();

mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,3 +249,47 @@ bool matcher::operatesOnSuperVectorsOf(Operation &op,
249249
// between parallel, reduction and possibly other cases.
250250
return ratio.has_value();
251251
}
252+
253+
bool vector::isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
254+
if (vectorType.isScalable())
255+
return false;
256+
257+
ArrayRef<int64_t> vectorShape = vectorType.getShape();
258+
auto vecRank = vectorType.getRank();
259+
260+
// Extract the trailing dims and strides of the input memref.
261+
auto memrefShape = memrefType.getShape().take_back(vecRank);
262+
int64_t offset;
263+
SmallVector<int64_t> stridesFull;
264+
if (!succeeded(getStridesAndOffset(memrefType, stridesFull, offset)))
265+
return false;
266+
auto strides = ArrayRef<int64_t>(stridesFull).take_back(vecRank);
267+
268+
// Cond 1: A contiguous memref will always have a unit trailing stride.
269+
if (strides.back() != 1)
270+
return false;
271+
272+
// Cond 2: Strides of a contiguous memref have to match the flattened dims.
273+
strides = strides.drop_back(1);
274+
SmallVector<int64_t> flattenedDims;
275+
for (size_t i = 1; i < memrefShape.size(); i++)
276+
flattenedDims.push_back(mlir::computeProduct(memrefShape.take_back(i)));
277+
278+
if (!llvm::equal(strides, llvm::reverse(flattenedDims)))
279+
return false;
280+
281+
// Cond 3: Compare the dims of `vectorType` against `memrefType` (in reverse).
282+
// In the most basic case, all dims will match.
283+
auto firstNonMatchingDim =
284+
std::mismatch(vectorShape.rbegin(), vectorShape.rend(),
285+
memrefShape.rbegin(), memrefShape.rend());
286+
if (firstNonMatchingDim.first == vectorShape.rend())
287+
return true;
288+
289+
// One non-matching dim is still fine, however the remaining leading dims of
290+
// `vectorType` need to be 1.
291+
SmallVector<int64_t> leadingDims(++firstNonMatchingDim.first,
292+
vectorShape.rend());
293+
294+
return llvm::all_of(leadingDims, [](auto x) { return x == 1; });
295+
}

mlir/test/Dialect/Vector/vector-transfer-flatten.mlir

Lines changed: 82 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
// RUN: mlir-opt %s -test-vector-transfer-flatten-patterns -split-input-file | FileCheck %s
22

3-
func.func @transfer_read_flattenable_with_offset(
3+
func.func @transfer_read_dims_match_contiguous(
44
%arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
55
%c0 = arith.constant 0 : index
66
%cst = arith.constant 0 : i8
@@ -9,7 +9,7 @@ func.func @transfer_read_flattenable_with_offset(
99
return %v : vector<5x4x3x2xi8>
1010
}
1111

12-
// CHECK-LABEL: func @transfer_read_flattenable_with_offset
12+
// CHECK-LABEL: func @transfer_read_dims_match_contiguous
1313
// CHECK-SAME: %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
1414
// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]
1515
// CHECK: %[[READ1D:.+]] = vector.transfer_read %[[COLLAPSED]]
@@ -18,15 +18,53 @@ func.func @transfer_read_flattenable_with_offset(
1818

1919
// -----
2020

21-
func.func @transfer_write_flattenable_with_offset(
21+
// The shape of the memref and the vector don't match, but the vector is a
22+
// contiguous subset of the memref, so "flattenable".
23+
24+
func.func @transfer_read_dims_mismatch_contiguous(
25+
%arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
26+
%c0 = arith.constant 0 : index
27+
%cst = arith.constant 0 : i8
28+
%v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
29+
memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<1x1x2x2xi8>
30+
return %v : vector<1x1x2x2xi8>
31+
}
32+
33+
// CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous(
34+
// CHECK-SAME: %[[VAL_0:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
35+
// CHECK: %[[VAL_1:.*]] = arith.constant 0 : i8
36+
// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
37+
// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0, 1, 2, 3]] : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<120xi8, strided<[1], offset: ?>>
38+
// CHECK: %[[VAL_4:.*]] = vector.transfer_read %[[VAL_3]]{{\[}}%[[VAL_2]]], %[[VAL_1]] {in_bounds = [true]} : memref<120xi8, strided<[1], offset: ?>>, vector<4xi8>
39+
// CHECK: %[[VAL_5:.*]] = vector.shape_cast %[[VAL_4]] : vector<4xi8> to vector<1x1x2x2xi8>
40+
// CHECK: return %[[VAL_5]] : vector<1x1x2x2xi8>
41+
42+
// -----
43+
44+
func.func @transfer_read_dims_mismatch_non_contiguous(
45+
%arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<2x1x2x2xi8> {
46+
%c0 = arith.constant 0 : index
47+
%cst = arith.constant 0 : i8
48+
%v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
49+
memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<2x1x2x2xi8>
50+
return %v : vector<2x1x2x2xi8>
51+
}
52+
53+
// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous
54+
// CHECK-NOT: memref.collapse_shape
55+
// CHECK-NOT: vector.shape_cast
56+
57+
// -----
58+
59+
func.func @transfer_write_dims_match_contiguous(
2260
%arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, %vec : vector<5x4x3x2xi8>) {
2361
%c0 = arith.constant 0 : index
2462
vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
2563
vector<5x4x3x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
2664
return
2765
}
2866

29-
// CHECK-LABEL: func @transfer_write_flattenable_with_offset
67+
// CHECK-LABEL: func @transfer_write_dims_match_contiguous
3068
// CHECK-SAME: %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
3169
// CHECK-SAME: %[[VEC:[0-9a-zA-Z]+]]: vector<5x4x3x2xi8>
3270
// CHECK-DAG: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]{{.}} : memref<5x4x3x2xi8, {{.+}}> into memref<120xi8, {{.+}}>
@@ -35,16 +73,48 @@ func.func @transfer_write_flattenable_with_offset(
3573

3674
// -----
3775

76+
func.func @transfer_write_dims_mismatch_contiguous(
77+
%arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, %vec : vector<1x1x2x2xi8>) {
78+
%c0 = arith.constant 0 : index
79+
vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
80+
vector<1x1x2x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
81+
return
82+
}
83+
84+
// CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous
85+
// CHECK-SAME: %[[VAL_0:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
86+
// CHECK-SAME: %[[VAL_1:.*]]: vector<1x1x2x2xi8>) {
87+
// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
88+
// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0, 1, 2, 3]] : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<120xi8, strided<[1], offset: ?>>
89+
// CHECK: %[[VAL_4:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x1x2x2xi8> to vector<4xi8>
90+
// CHECK: vector.transfer_write %[[VAL_4]], %[[VAL_3]]{{\[}}%[[VAL_2]]] {in_bounds = [true]} : vector<4xi8>, memref<120xi8, strided<[1], offset: ?>>
91+
// CHECK: return
92+
// CHECK: }
93+
94+
// -----
95+
96+
func.func @transfer_write_dims_mismatch_non_contiguous(
97+
%arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, %vec : vector<2x1x2x2xi8>) {
98+
%c0 = arith.constant 0 : index
99+
vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
100+
vector<2x1x2x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
101+
return
102+
}
103+
104+
// CHECK-LABEL: func.func @transfer_write_dims_mismatch_non_contiguous
105+
// CHECK-NOT: memref.collapse_shape
106+
// CHECK-NOT: vector.shape_cast
107+
108+
// -----
109+
38110
func.func @transfer_write_0d(%arg : memref<i8>, %vec : vector<i8>) {
39111
vector.transfer_write %vec, %arg[] : vector<i8>, memref<i8>
40112
return
41113
}
42114

43-
// CHECK-LABEL: func @transfer_write_0d
44-
// CHECK-SAME: %[[ARG:.+]]: memref<i8>
45-
// CHECK-SAME: %[[VEC:.+]]: vector<i8>
46-
// CHECK: vector.transfer_write %[[VEC]], %[[ARG]][] : vector<i8>, memref<i8>
47-
// CHECK: return
115+
// CHECK-LABEL: func.func @transfer_write_0d
116+
// CHECK-NOT: memref.collapse_shape
117+
// CHECK-NOT: vector.shape_cast
48118

49119
// -----
50120

@@ -54,11 +124,9 @@ func.func @transfer_read_0d(%arg : memref<i8>) -> vector<i8> {
54124
return %0 : vector<i8>
55125
}
56126

57-
// CHECK-LABEL: func @transfer_read_0d
58-
// CHECK-SAME: %[[ARG:.+]]: memref<i8>
59-
// CHECK: %[[CST:.+]] = arith.constant 0 : i8
60-
// CHECK: %[[READ:.+]] = vector.transfer_read %[[ARG]][], %[[CST]] : memref<i8>
61-
// CHECK: return %[[READ]]
127+
// CHECK-LABEL: func.func @transfer_read_0d
128+
// CHECK-NOT: memref.collapse_shape
129+
// CHECK-NOT: vector.shape_cast
62130

63131
// -----
64132

0 commit comments

Comments
 (0)