Skip to content

Commit a53cd03

Browse files
committed
[mlir][Vector] Extend xfer drop unit dim patterns
This patch extends the transfer drop unit dim patterns to support cases where the vector shape should also be reduced (e.g., transfer_read(memref<1x4x1xf32>, vector<1x4x1xf32>) -> transfer_read(memref<4xf32>, vector<4xf32>). Reviewed By: hanchung, pzread Differential Revision: https://reviews.llvm.org/D151007
1 parent 7f033d0 commit a53cd03

File tree

2 files changed

+151
-15
lines changed

2 files changed

+151
-15
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp

Lines changed: 52 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ class TransferOptimization {
6363
std::vector<Operation *> opToErase;
6464
};
6565

66+
} // namespace
6667
/// Return true if there is a path from start operation to dest operation,
6768
/// otherwise return false. The operations have to be in the same region.
6869
bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
@@ -288,14 +289,25 @@ static int getReducedRank(ArrayRef<int64_t> shape) {
288289
return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; });
289290
}
290291

292+
/// Returns a copy of `shape` without unit dims.
293+
static SmallVector<int64_t> getReducedShape(ArrayRef<int64_t> shape) {
294+
SmallVector<int64_t> reducedShape;
295+
llvm::copy_if(shape, std::back_inserter(reducedShape),
296+
[](int64_t dimSize) { return dimSize != 1; });
297+
return reducedShape;
298+
}
299+
291300
/// Returns true if all values are `arith.constant 0 : index`
292301
static bool isZero(Value v) {
293302
auto cst = v.getDefiningOp<arith::ConstantIndexOp>();
294303
return cst && cst.value() == 0;
295304
}
296305

297-
/// Rewrites vector.transfer_read ops where the source has unit dims, by
298-
/// inserting a memref.subview dropping those unit dims.
306+
namespace {
307+
308+
/// Rewrites `vector.transfer_read` ops where the source has unit dims, by
309+
/// inserting a memref.subview dropping those unit dims. The vector shapes are
310+
/// also reduced accordingly.
299311
class TransferReadDropUnitDimsPattern
300312
: public OpRewritePattern<vector::TransferReadOp> {
301313
using OpRewritePattern::OpRewritePattern;
@@ -317,12 +329,15 @@ class TransferReadDropUnitDimsPattern
317329
return failure();
318330
if (!transferReadOp.getPermutationMap().isMinorIdentity())
319331
return failure();
332+
// Check if the source shape can be further reduced.
320333
int reducedRank = getReducedRank(sourceType.getShape());
321334
if (reducedRank == sourceType.getRank())
322-
return failure(); // The source shape can't be further reduced.
323-
if (reducedRank != vectorType.getRank())
324-
return failure(); // This pattern requires the vector shape to match the
325-
// reduced source shape.
335+
return failure();
336+
// Check if the reduced vector shape matches the reduced source shape.
337+
// Otherwise, this case is not supported yet.
338+
int vectorReducedRank = getReducedRank(vectorType.getShape());
339+
if (reducedRank != vectorReducedRank)
340+
return failure();
326341
if (llvm::any_of(transferReadOp.getIndices(),
327342
[](Value v) { return !isZero(v); }))
328343
return failure();
@@ -331,14 +346,22 @@ class TransferReadDropUnitDimsPattern
331346
Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
332347
SmallVector<Value> zeros(reducedRank, c0);
333348
auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
334-
rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
335-
transferReadOp, vectorType, reducedShapeSource, zeros, identityMap);
349+
auto reducedVectorType = VectorType::get(
350+
getReducedShape(vectorType.getShape()), vectorType.getElementType());
351+
352+
auto newTransferReadOp = rewriter.create<vector::TransferReadOp>(
353+
loc, reducedVectorType, reducedShapeSource, zeros, identityMap);
354+
auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
355+
loc, vectorType, newTransferReadOp);
356+
rewriter.replaceOp(transferReadOp, shapeCast);
357+
336358
return success();
337359
}
338360
};
339361

340-
/// Rewrites vector.transfer_write ops where the "source" (i.e. destination) has
341-
/// unit dims, by inserting a memref.subview dropping those unit dims.
362+
/// Rewrites `vector.transfer_write` ops where the "source" (i.e. destination)
363+
/// has unit dims, by inserting a `memref.subview` dropping those unit dims. The
364+
/// vector shapes are also reduced accordingly.
342365
class TransferWriteDropUnitDimsPattern
343366
: public OpRewritePattern<vector::TransferWriteOp> {
344367
using OpRewritePattern::OpRewritePattern;
@@ -360,12 +383,15 @@ class TransferWriteDropUnitDimsPattern
360383
return failure();
361384
if (!transferWriteOp.getPermutationMap().isMinorIdentity())
362385
return failure();
386+
// Check if the destination shape can be further reduced.
363387
int reducedRank = getReducedRank(sourceType.getShape());
364388
if (reducedRank == sourceType.getRank())
365-
return failure(); // The source shape can't be further reduced.
366-
if (reducedRank != vectorType.getRank())
367-
return failure(); // This pattern requires the vector shape to match the
368-
// reduced source shape.
389+
return failure();
390+
// Check if the reduced vector shape matches the reduced destination shape.
391+
// Otherwise, this case is not supported yet.
392+
int vectorReducedRank = getReducedRank(vectorType.getShape());
393+
if (reducedRank != vectorReducedRank)
394+
return failure();
369395
if (llvm::any_of(transferWriteOp.getIndices(),
370396
[](Value v) { return !isZero(v); }))
371397
return failure();
@@ -374,12 +400,20 @@ class TransferWriteDropUnitDimsPattern
374400
Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
375401
SmallVector<Value> zeros(reducedRank, c0);
376402
auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
403+
VectorType reducedVectorType = VectorType::get(
404+
getReducedShape(vectorType.getShape()), vectorType.getElementType());
405+
406+
auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
407+
loc, reducedVectorType, vector);
377408
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
378-
transferWriteOp, vector, reducedShapeSource, zeros, identityMap);
409+
transferWriteOp, shapeCast, reducedShapeSource, zeros, identityMap);
410+
379411
return success();
380412
}
381413
};
382414

415+
} // namespace
416+
383417
/// Return true if the memref type has its inner dimension matching the given
384418
/// shape. Otherwise return false.
385419
static int64_t hasMatchingInnerContigousShape(MemRefType memrefType,
@@ -439,6 +473,8 @@ checkAndCollapseInnerZeroIndices(ValueRange indices, int64_t firstDimToCollapse,
439473
return success();
440474
}
441475

476+
namespace {
477+
442478
/// Rewrites contiguous row-major vector.transfer_read ops by inserting
443479
/// memref.collapse_shape on the source so that the resulting
444480
/// vector.transfer_read has a 1D source. Requires the source shape to be
@@ -732,6 +768,7 @@ class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> {
732768
return success();
733769
}
734770
};
771+
735772
} // namespace
736773

737774
void mlir::vector::transferOpflowOpt(RewriterBase &rewriter,

mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,14 @@ func.func @transfer_read_rank_reducing(
1515
// CHECK-SAME: memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
1616
// CHECK: vector.transfer_read %[[SUBVIEW]]
1717

18+
transform.sequence failures(propagate) {
19+
^bb1(%module_op: !pdl.operation):
20+
transform.vector.apply_rank_reducing_subview_patterns %module_op
21+
: (!pdl.operation) -> !pdl.operation
22+
}
23+
24+
// -----
25+
1826
func.func @transfer_write_rank_reducing(%arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>, %vec : vector<3x2xi8>) {
1927
%c0 = arith.constant 0 : index
2028
vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
@@ -28,6 +36,97 @@ func.func @transfer_write_rank_reducing(%arg : memref<1x1x3x2xi8, strided<[6, 6,
2836
// CHECK-SAME: memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
2937
// CHECK: vector.transfer_write %{{.*}}, %[[SUBVIEW]]
3038

39+
transform.sequence failures(propagate) {
40+
^bb1(%module_op: !pdl.operation):
41+
transform.vector.apply_rank_reducing_subview_patterns %module_op
42+
: (!pdl.operation) -> !pdl.operation
43+
}
44+
45+
// -----
46+
47+
func.func @transfer_read_and_vector_rank_reducing(
48+
%arg : memref<1x1x3x2x1xf32>) -> vector<3x2x1xf32> {
49+
%c0 = arith.constant 0 : index
50+
%cst = arith.constant 0.0 : f32
51+
%v = vector.transfer_read %arg[%c0, %c0, %c0, %c0, %c0], %cst :
52+
memref<1x1x3x2x1xf32>, vector<3x2x1xf32>
53+
return %v : vector<3x2x1xf32>
54+
}
55+
56+
// CHECK-LABEL: func @transfer_read_and_vector_rank_reducing
57+
// CHECK-SAME: %[[ARG:.+]]: memref<1x1x3x2x1xf32>
58+
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0, 0] [1, 1, 3, 2, 1] [1, 1, 1, 1, 1]
59+
// CHECK-SAME: memref<1x1x3x2x1xf32> to memref<3x2xf32>
60+
// CHECK: vector.transfer_read %[[SUBVIEW]]{{.*}} {in_bounds = [true, true]} : memref<3x2xf32>, vector<3x2xf32>
61+
62+
transform.sequence failures(propagate) {
63+
^bb1(%module_op: !pdl.operation):
64+
transform.vector.apply_rank_reducing_subview_patterns %module_op
65+
: (!pdl.operation) -> !pdl.operation
66+
}
67+
68+
// -----
69+
70+
func.func @transfer_write_and_vector_rank_reducing(
71+
%arg : memref<1x1x3x2x1xf32>,
72+
%vec : vector<3x2x1xf32>) {
73+
%c0 = arith.constant 0 : index
74+
vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0, %c0] :
75+
vector<3x2x1xf32>, memref<1x1x3x2x1xf32>
76+
return
77+
}
78+
79+
// CHECK-LABEL: func @transfer_write_and_vector_rank_reducing
80+
// CHECK-SAME: %[[ARG:.+]]: memref<1x1x3x2x1xf32>
81+
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0, 0] [1, 1, 3, 2, 1] [1, 1, 1, 1, 1]
82+
// CHECK-SAME: memref<1x1x3x2x1xf32> to memref<3x2xf32>
83+
// CHECK: vector.transfer_write %{{.*}}, %[[SUBVIEW]]{{.*}} {in_bounds = [true, true]} : vector<3x2xf32>, memref<3x2xf32>
84+
85+
transform.sequence failures(propagate) {
86+
^bb1(%module_op: !transform.any_op):
87+
transform.vector.apply_rank_reducing_subview_patterns %module_op
88+
: (!transform.any_op) -> !transform.any_op
89+
}
90+
91+
// -----
92+
93+
func.func @transfer_read_and_vector_rank_reducing_to_0d(
94+
%arg : memref<1x1x1x1x1xf32>) -> vector<1x1x1xf32> {
95+
%c0 = arith.constant 0 : index
96+
%cst = arith.constant 0.0 : f32
97+
%v = vector.transfer_read %arg[%c0, %c0, %c0, %c0, %c0], %cst :
98+
memref<1x1x1x1x1xf32>, vector<1x1x1xf32>
99+
return %v : vector<1x1x1xf32>
100+
}
101+
102+
// CHECK-LABEL: func @transfer_read_and_vector_rank_reducing_to_0d
103+
// CHECK-SAME: %[[MEMREF:.+]]: memref<1x1x1x1x1xf32>
104+
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[MEMREF]][0, 0, 0, 0, 0] [1, 1, 1, 1, 1] [1, 1, 1, 1, 1] : memref<1x1x1x1x1xf32> to memref<f32>
105+
// CHECK: %[[READ:.+]] = vector.transfer_read %[[SUBVIEW]]{{.*}} : memref<f32>, vector<f32>
106+
// CHECK: vector.shape_cast %[[READ]] : vector<f32> to vector<1x1x1xf32>
107+
108+
transform.sequence failures(propagate) {
109+
^bb1(%module_op: !pdl.operation):
110+
transform.vector.apply_rank_reducing_subview_patterns %module_op
111+
: (!pdl.operation) -> !pdl.operation
112+
}
113+
114+
// -----
115+
116+
func.func @transfer_write_and_vector_rank_reducing_to_0d(
117+
%arg : memref<1x1x1x1x1xf32>,
118+
%vec : vector<1x1x1xf32>) {
119+
%c0 = arith.constant 0 : index
120+
vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0, %c0] :
121+
vector<1x1x1xf32>, memref<1x1x1x1x1xf32>
122+
return
123+
}
124+
125+
// CHECK-LABEL: func @transfer_write_and_vector_rank_reducing_to_0d
126+
// CHECK-SAME: %[[MEMREF:.+]]: memref<1x1x1x1x1xf32>, %[[VECTOR:.+]]: vector<1x1x1xf32>
127+
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[MEMREF]][0, 0, 0, 0, 0] [1, 1, 1, 1, 1] [1, 1, 1, 1, 1] : memref<1x1x1x1x1xf32> to memref<f32>
128+
// CHECK: %[[SHCAST:.+]] = vector.shape_cast %[[VECTOR]] : vector<1x1x1xf32> to vector<f32>
129+
// CHECK: vector.transfer_write %[[SHCAST]], %[[SUBVIEW]]{{.*}} : vector<f32>, memref<f32>
31130

32131
transform.sequence failures(propagate) {
33132
^bb1(%module_op: !transform.any_op):

0 commit comments

Comments
 (0)