Skip to content

Commit 902ccc3

Browse files
committed
[mlir][Vector] Update patterns for flattening vector.xfer Ops
Updates "flatten vector" patterns to support more cases, namely Ops that read/write vectors with leading unit dims. For example: ```mlir %0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0] ... : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<1x1x2x2xi8> ``` Currently, this `vector.transfer_read` would not be flattened. With this change, it will be transformed as follows: ```mlir %collapse_shape = memref.collapse_shape %arg0 [[0, 1, 2, 3]] : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<120xi8, strided<[1], offset: ?>> %0 = vector.transfer_read %collapse_shape[%c0] ... : memref<120xi8, strided<[1], offset: ?>>, vector<4xi8> %1 = vector.shape_cast %0 : vector<4xi8> to vector<1x1x2x2xi8> ``` `hasMatchingInnerContigousShape` is generalised and renamed as `isContiguousSlice` to better match the updated functionality. A few test names are updated to better highlight what case is being exercised.
1 parent 79b0330 commit 902ccc3

File tree

2 files changed

+140
-31
lines changed

2 files changed

+140
-31
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp

Lines changed: 62 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -487,26 +487,75 @@ class TransferWriteDropUnitDimsPattern
487487

488488
} // namespace
489489

490-
/// Return true if the memref type has its inner dimension matching the given
491-
/// shape. Otherwise return false.
492-
static int64_t hasMatchingInnerContigousShape(MemRefType memrefType,
493-
ArrayRef<int64_t> targetShape) {
494-
auto shape = memrefType.getShape();
495-
SmallVector<int64_t> strides;
490+
/// Return true if `vectorType` is a contiguous slice of `memrefType`.
491+
///
492+
/// Compares `vectorType` against the trailing dimensions (*) of `memrefType`
493+
/// to check whether `vectorType` is a contiguous slice of `memrefType`.
494+
///
495+
/// There are two cases:
496+
///
497+
/// 1. The trailing dimensions of `memrefType` match the dimensions of
498+
/// `vectorType` excluding the front dim (the leading dim of `vectorType` does
499+
/// not matter in this case):
500+
///
501+
/// vector<2x4x3x2xi32> vs memref<5x4x3x2xi32> (contiguous slice)
502+
/// vector<2x4x2x2xi32> vs memref<5x4x3x2xi32> (non-contiguous slice)
503+
///
504+
/// 2. The trailing dimensions of `memrefType` match the trailing dimensions of
505+
/// `vectorType` (i.e. at least 2 leading dims of `vectorType` don't match). The
506+
/// first dim of `vectorType` that does not match can be arbitrary, but the
507+
/// remaining leading dims have to be 1:
508+
///
509+
/// vector<1x1x2x2xi32> vs memref<5x4x3x2xi32> (contiguous slice)
510+
/// vector<2x1x2x2xi32> vs memref<5x4x3x2xi32> (non-contiguous slice)
511+
///
512+
/// In both cases `memrefType` has to be contiguous (this is checked by looking
513+
/// at strides).
514+
///
515+
/// (*) Only relevant in cases when the rank(vectorType) < rank(memrefType)
516+
/// TODO: Update
517+
static bool isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
518+
519+
ArrayRef<int64_t> targetShape = vectorType.getShape();
520+
auto targetShapeTrailingDims = targetShape.drop_front(1);
521+
522+
// Not used, but required as an out-parameter by getStridesAndOffset.
496523
int64_t offset;
524+
SmallVector<int64_t> strides;
497525
if (!succeeded(getStridesAndOffset(memrefType, strides, offset)))
498526
return false;
527+
528+
// Non-unit stride in the trailing dimension means that this memref is
529
// not contiguous.
499530
if (strides.back() != 1)
500531
return false;
501-
strides.pop_back();
532+
533+
// Do all but the leading dim of `vectorType` and the trailing dims of
534+
// `memrefType` match?
535+
bool allTrailingDimsMatch = true;
536+
537+
// The trailing dimension of `memrefType` after collapsing/flattening the
538+
// current dim. This will be a product of the leading dims, hence initialising
539+
// to 1.
502540
int64_t flatDim = 1;
503-
for (auto [targetDim, memrefDim, memrefStride] :
504-
llvm::reverse(llvm::zip(targetShape, shape, strides))) {
541+
strides.pop_back();
542+
for (auto [targetDim, memrefDim, memrefStride] : llvm::reverse(llvm::zip(
543+
targetShapeTrailingDims, memrefType.getShape(), strides))) {
505544
flatDim *= memrefDim;
506-
if (flatDim != memrefStride || targetDim != memrefDim)
545+
// If the memref stride does not match the flattened dim, then this
546+
// memref is not contiguous.
547+
if (flatDim != memrefStride)
548+
return false;
549+
550+
// If a non-matching dim was found, then the remaining dims of `vectorType`
551+
// should be 1.
552+
if (!allTrailingDimsMatch && (targetDim != 1))
507553
return false;
554+
555+
allTrailingDimsMatch = (targetDim == memrefDim);
508556
}
509-
return true;
557+
558+
return allTrailingDimsMatch ? true : (targetShape[0] == 1);
510559
}
511560

512561
/// Creates a memref.collapse_shape collapsing all inner dimensions of the
@@ -568,9 +617,7 @@ class FlattenContiguousRowMajorTransferReadPattern
568617
if (vectorType.getRank() <= 1)
569618
// Already 0D/1D, nothing to do.
570619
return failure();
571-
if (!hasMatchingInnerContigousShape(
572-
sourceType,
573-
vectorType.getShape().take_back(vectorType.getRank() - 1)))
620+
if (!isContiguousSlice(sourceType, vectorType))
574621
return failure();
575622
int64_t firstContiguousInnerDim =
576623
sourceType.getRank() - vectorType.getRank();
@@ -628,9 +675,7 @@ class FlattenContiguousRowMajorTransferWritePattern
628675
if (vectorType.getRank() <= 1)
629676
// Already 0D/1D, nothing to do.
630677
return failure();
631-
if (!hasMatchingInnerContigousShape(
632-
sourceType,
633-
vectorType.getShape().take_back(vectorType.getRank() - 1)))
678+
if (!isContiguousSlice(sourceType, vectorType))
634679
return failure();
635680
int64_t firstContiguousInnerDim =
636681
sourceType.getRank() - vectorType.getRank();

mlir/test/Dialect/Vector/vector-transfer-flatten.mlir

Lines changed: 78 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
// RUN: mlir-opt %s -test-vector-transfer-flatten-patterns -split-input-file | FileCheck %s
22

3-
func.func @transfer_read_flattenable_with_offset(
3+
func.func @transfer_read_dims_match_contiguous(
44
%arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
55
%c0 = arith.constant 0 : index
66
%cst = arith.constant 0 : i8
@@ -9,7 +9,7 @@ func.func @transfer_read_flattenable_with_offset(
99
return %v : vector<5x4x3x2xi8>
1010
}
1111

12-
// CHECK-LABEL: func @transfer_read_flattenable_with_offset
12+
// CHECK-LABEL: func @transfer_read_dims_match_contiguous
1313
// CHECK-SAME: %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
1414
// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]
1515
// CHECK: %[[READ1D:.+]] = vector.transfer_read %[[COLLAPSED]]
@@ -18,15 +18,52 @@ func.func @transfer_read_flattenable_with_offset(
1818

1919
// -----
2020

21-
func.func @transfer_write_flattenable_with_offset(
21+
// The shape of the memref and the vector don't match, but the vector is a
22+
// contiguous subset of the memref, so "flattenable".
23+
24+
func.func @transfer_read_dims_mismatch_contiguous(
25+
%arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
26+
%c0 = arith.constant 0 : index
27+
%cst = arith.constant 0 : i8
28+
%v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
29+
memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<1x1x2x2xi8>
30+
return %v : vector<1x1x2x2xi8>
31+
}
32+
33+
// CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous(
34+
// CHECK-SAME: %[[VAL_0:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
35+
// CHECK: %[[VAL_1:.*]] = arith.constant 0 : i8
36+
// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
37+
// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0, 1, 2, 3]] : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<120xi8, strided<[1], offset: ?>>
38+
// CHECK: %[[VAL_4:.*]] = vector.transfer_read %[[VAL_3]]{{\[}}%[[VAL_2]]], %[[VAL_1]] {in_bounds = [true]} : memref<120xi8, strided<[1], offset: ?>>, vector<4xi8>
39+
// CHECK: %[[VAL_5:.*]] = vector.shape_cast %[[VAL_4]] : vector<4xi8> to vector<1x1x2x2xi8>
40+
// CHECK: return %[[VAL_5]] : vector<1x1x2x2xi8>
41+
42+
// -----
43+
44+
func.func @transfer_read_dims_mismatch_contiguous(
45+
%arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<2x1x2x2xi8> {
46+
%c0 = arith.constant 0 : index
47+
%cst = arith.constant 0 : i8
48+
%v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
49+
memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<2x1x2x2xi8>
50+
return %v : vector<2x1x2x2xi8>
51+
}
52+
53+
// CHECK-NOT: memref.collapse_shape
54+
// CHECK-NOT: vector.shape_cast
55+
56+
// -----
57+
58+
func.func @transfer_write_dims_match_contiguous(
2259
%arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, %vec : vector<5x4x3x2xi8>) {
2360
%c0 = arith.constant 0 : index
2461
vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
2562
vector<5x4x3x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
2663
return
2764
}
2865

29-
// CHECK-LABEL: func @transfer_write_flattenable_with_offset
66+
// CHECK-LABEL: func @transfer_write_dims_match_contiguous
3067
// CHECK-SAME: %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
3168
// CHECK-SAME: %[[VEC:[0-9a-zA-Z]+]]: vector<5x4x3x2xi8>
3269
// CHECK-DAG: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]{{.}} : memref<5x4x3x2xi8, {{.+}}> into memref<120xi8, {{.+}}>
@@ -35,16 +72,46 @@ func.func @transfer_write_flattenable_with_offset(
3572

3673
// -----
3774

75+
func.func @transfer_write_dims_mismatch_contiguous(
76+
%arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, %vec : vector<1x1x2x2xi8>) {
77+
%c0 = arith.constant 0 : index
78+
vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
79+
vector<1x1x2x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
80+
return
81+
}
82+
83+
// CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous
84+
// CHECK-SAME: %[[VAL_0:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
85+
// CHECK-SAME: %[[VAL_1:.*]]: vector<1x1x2x2xi8>) {
86+
// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
87+
// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0, 1, 2, 3]] : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<120xi8, strided<[1], offset: ?>>
88+
// CHECK: %[[VAL_4:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x1x2x2xi8> to vector<4xi8>
89+
// CHECK: vector.transfer_write %[[VAL_4]], %[[VAL_3]]{{\[}}%[[VAL_2]]] {in_bounds = [true]} : vector<4xi8>, memref<120xi8, strided<[1], offset: ?>>
90+
// CHECK: return
91+
// CHECK: }
92+
93+
// -----
94+
95+
func.func @transfer_write_dims_mismatch_non_contiguous(
96+
%arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, %vec : vector<2x1x2x2xi8>) {
97+
%c0 = arith.constant 0 : index
98+
vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
99+
vector<2x1x2x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
100+
return
101+
}
102+
103+
// CHECK-NOT: memref.collapse_shape
104+
// CHECK-NOT: vector.shape_cast
105+
106+
// -----
107+
38108
func.func @transfer_write_0d(%arg : memref<i8>, %vec : vector<i8>) {
39109
vector.transfer_write %vec, %arg[] : vector<i8>, memref<i8>
40110
return
41111
}
42112

43-
// CHECK-LABEL: func @transfer_write_0d
44-
// CHECK-SAME: %[[ARG:.+]]: memref<i8>
45-
// CHECK-SAME: %[[VEC:.+]]: vector<i8>
46-
// CHECK: vector.transfer_write %[[VEC]], %[[ARG]][] : vector<i8>, memref<i8>
47-
// CHECK: return
113+
// CHECK-NOT: memref.collapse_shape
114+
// CHECK-NOT: vector.shape_cast
48115

49116
// -----
50117

@@ -54,11 +121,8 @@ func.func @transfer_read_0d(%arg : memref<i8>) -> vector<i8> {
54121
return %0 : vector<i8>
55122
}
56123

57-
// CHECK-LABEL: func @transfer_read_0d
58-
// CHECK-SAME: %[[ARG:.+]]: memref<i8>
59-
// CHECK: %[[CST:.+]] = arith.constant 0 : i8
60-
// CHECK: %[[READ:.+]] = vector.transfer_read %[[ARG]][], %[[CST]] : memref<i8>
61-
// CHECK: return %[[READ]]
124+
// CHECK-NOT: memref.collapse_shape
125+
// CHECK-NOT: vector.shape_cast
62126

63127
// -----
64128

0 commit comments

Comments
 (0)