Skip to content

Commit 0de9b8c

Browse files
committed
[MLIR][Vector] Update Transfer{Read|Write}DropUnitDimsPattern patterns
Updates `TransferWriteDropUnitDimsPattern` and `TransferReadDropUnitDimsPattern` to inherit from `MaskableOpRewritePattern` so that masked versions of xfer_read/xfer_write Ops are also supported: ```mlir %v = vector.mask %mask { vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>, vector<3x2xi8> } : vector<3x2xi1> -> vector<3x2xi8> ```
1 parent e13f1d1 commit 0de9b8c

File tree

2 files changed

+89
-18
lines changed

2 files changed

+89
-18
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp

Lines changed: 42 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -354,11 +354,13 @@ namespace {
354354
/// inserting a memref.subview dropping those unit dims. The vector shapes are
355355
/// also reduced accordingly.
356356
class TransferReadDropUnitDimsPattern
357-
: public OpRewritePattern<vector::TransferReadOp> {
358-
using OpRewritePattern::OpRewritePattern;
357+
: public vector::MaskableOpRewritePattern<vector::TransferReadOp> {
358+
using MaskableOpRewritePattern::MaskableOpRewritePattern;
359359

360-
LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
361-
PatternRewriter &rewriter) const override {
360+
FailureOr<Value>
361+
matchAndRewriteMaskableOp(vector::TransferReadOp transferReadOp,
362+
vector::MaskingOpInterface maskingOp,
363+
PatternRewriter &rewriter) const override {
362364
auto loc = transferReadOp.getLoc();
363365
Value vector = transferReadOp.getVector();
364366
VectorType vectorType = cast<VectorType>(vector.getType());
@@ -406,27 +408,37 @@ class TransferReadDropUnitDimsPattern
406408
SmallVector<Value> zeros(reducedRank, c0);
407409
auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
408410
SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
409-
auto newTransferReadOp = rewriter.create<vector::TransferReadOp>(
411+
Operation *newTransferReadOp = rewriter.create<vector::TransferReadOp>(
410412
loc, reducedVectorType, reducedShapeSource, zeros, identityMap,
411413
transferReadOp.getPadding(), maskOp,
412414
rewriter.getBoolArrayAttr(inBounds));
415+
416+
if (maskingOp) {
417+
auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
418+
loc, reducedVectorType.cloneWith({}, rewriter.getI1Type()),
419+
maskingOp.getMask());
420+
newTransferReadOp = mlir::vector::maskOperation(
421+
rewriter, newTransferReadOp, shapeCastMask);
422+
}
423+
413424
auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
414-
loc, vectorType, newTransferReadOp);
415-
rewriter.replaceOp(transferReadOp, shapeCast);
425+
loc, vectorType, newTransferReadOp->getResults()[0]);
416426

417-
return success();
427+
return shapeCast;
418428
}
419429
};
420430

421431
/// Rewrites `vector.transfer_write` ops where the "source" (i.e. destination)
422432
/// has unit dims, by inserting a `memref.subview` dropping those unit dims. The
423433
/// vector shapes are also reduced accordingly.
424434
class TransferWriteDropUnitDimsPattern
425-
: public OpRewritePattern<vector::TransferWriteOp> {
426-
using OpRewritePattern::OpRewritePattern;
435+
: public vector::MaskableOpRewritePattern<vector::TransferWriteOp> {
436+
using MaskableOpRewritePattern::MaskableOpRewritePattern;
427437

428-
LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
429-
PatternRewriter &rewriter) const override {
438+
FailureOr<Value>
439+
matchAndRewriteMaskableOp(vector::TransferWriteOp transferWriteOp,
440+
vector::MaskingOpInterface maskingOp,
441+
PatternRewriter &rewriter) const override {
430442
auto loc = transferWriteOp.getLoc();
431443
Value vector = transferWriteOp.getVector();
432444
VectorType vectorType = cast<VectorType>(vector.getType());
@@ -474,13 +486,26 @@ class TransferWriteDropUnitDimsPattern
474486
SmallVector<Value> zeros(reducedRank, c0);
475487
auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
476488
SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
477-
auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
489+
auto shapeCastSrc = rewriter.createOrFold<vector::ShapeCastOp>(
478490
loc, reducedVectorType, vector);
479-
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
480-
transferWriteOp, Type(), shapeCast, reducedShapeSource, zeros,
481-
identityMap, maskOp, rewriter.getBoolArrayAttr(inBounds));
491+
Operation *newXferWrite = rewriter.create<vector::TransferWriteOp>(
492+
loc, Type(), shapeCastSrc, reducedShapeSource, zeros, identityMap,
493+
maskOp, rewriter.getBoolArrayAttr(inBounds));
494+
495+
if (maskingOp) {
496+
auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
497+
loc, reducedVectorType.cloneWith({}, rewriter.getI1Type()),
498+
maskingOp.getMask());
499+
newXferWrite =
500+
mlir::vector::maskOperation(rewriter, newXferWrite, shapeCastMask);
501+
}
482502

483-
return success();
503+
if (transferWriteOp.hasPureTensorSemantics())
504+
return newXferWrite->getResults()[0];
505+
506+
// With Memref semantics, there's no return value. Use empty value to signal
507+
// success.
508+
return Value();
484509
}
485510
};
486511

mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
22

3+
//-----------------------------------------------------------------------------
4+
// [Patterns: TransferWriteDropUnitDimsPattern, TransferReadeDropUnitDimsPattern]
5+
//-----------------------------------------------------------------------------
6+
37
func.func @transfer_read_rank_reducing(
48
%arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>) -> vector<3x2xi8> {
59
%c0 = arith.constant 0 : index
@@ -14,7 +18,29 @@ func.func @transfer_read_rank_reducing(
1418
// CHECK-SAME: memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
1519
// CHECK: vector.transfer_read %[[SUBVIEW]]
1620

17-
func.func @transfer_write_rank_reducing(%arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>, %vec : vector<3x2xi8>) {
21+
func.func @transfer_read_rank_reducing_masked(
22+
%arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>,
23+
%mask: vector<3x2xi1>) -> vector<3x2xi8> {
24+
%c0 = arith.constant 0 : index
25+
%cst = arith.constant 0 : i8
26+
%v = vector.mask %mask {
27+
vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
28+
memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>, vector<3x2xi8>
29+
} : vector<3x2xi1> -> vector<3x2xi8>
30+
return %v : vector<3x2xi8>
31+
}
32+
// CHECK-LABEL: func @transfer_read_rank_reducing_masked
33+
// CHECK-SAME: %[[ARG:.+]]: memref<1x1x3x2xi8
34+
// CHECK-SAME: %[[MASK:.+]]: vector<3x2xi1>
35+
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0] [1, 1, 3, 2] [1, 1, 1, 1]
36+
// CHECK-SAME: memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
37+
// CHECK: vector.mask %[[MASK]]
38+
// CHECK-SAME: vector.transfer_read %[[SUBVIEW]]
39+
40+
func.func @transfer_write_rank_reducing(
41+
%arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>,
42+
%vec : vector<3x2xi8>) {
43+
1844
%c0 = arith.constant 0 : index
1945
vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
2046
vector<3x2xi8>, memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>
@@ -26,6 +52,26 @@ func.func @transfer_write_rank_reducing(%arg : memref<1x1x3x2xi8, strided<[6, 6,
2652
// CHECK-SAME: memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
2753
// CHECK: vector.transfer_write %{{.*}}, %[[SUBVIEW]]
2854

55+
func.func @transfer_write_rank_reducing_masked(
56+
%arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>,
57+
%vec : vector<3x2xi8>,
58+
%mask: vector<3x2xi1>) {
59+
%c0 = arith.constant 0 : index
60+
vector.mask %mask {
61+
vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
62+
vector<3x2xi8>, memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>
63+
} : vector<3x2xi1>
64+
return
65+
}
66+
// CHECK-LABEL: func @transfer_write_rank_reducing_masked
67+
// CHECK-SAME: %[[ARG:.+]]: memref<1x1x3x2xi8
68+
// CHECK-SAME: %[[VEC:.+]]: vector<3x2xi8>
69+
// CHECK-SAME: %[[MASK:.+]]: vector<3x2xi1>
70+
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0] [1, 1, 3, 2] [1, 1, 1, 1]
71+
// CHECK-SAME: memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
72+
// CHECK: vector.mask %[[MASK]]
73+
// CHECK-SAME: vector.transfer_write %{{.*}}, %[[SUBVIEW]]
74+
2975
func.func @transfer_read_and_vector_rank_reducing(
3076
%arg : memref<1x1x3x2x1xf32>) -> vector<3x2x1xf32> {
3177
%c0 = arith.constant 0 : index

0 commit comments

Comments
 (0)