Skip to content

Commit f478ae9

Browse files
nujaaAlexisPerry
authored andcommitted
[MLIR][Vector] Implement XferOp To {Load|Store}Lowering as MaskableOpRewritePattern (llvm#92892)
Implements `TransferReadToVectorLoadLowering` and `TransferWriteToVectorStoreLowering` as a `MaskableOpRewritePattern`. Allowing to exit gracefully when run on an xferOp located inside a `vector::MaskOp` instead of breaking because the pattern generated multiple ops in the MaskOp with `error: 'vector.mask' op expects only one operation to mask`. Split of llvm#90835
1 parent ac2d122 commit f478ae9

File tree

2 files changed

+48
-25
lines changed

2 files changed

+48
-25
lines changed

mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp

Lines changed: 31 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -429,20 +429,24 @@ namespace {
429429
/// result type.
430430
/// - The permutation map doesn't perform permutation (broadcasting is allowed).
431431
struct TransferReadToVectorLoadLowering
432-
: public OpRewritePattern<vector::TransferReadOp> {
432+
: public MaskableOpRewritePattern<vector::TransferReadOp> {
433433
TransferReadToVectorLoadLowering(MLIRContext *context,
434434
std::optional<unsigned> maxRank,
435435
PatternBenefit benefit = 1)
436-
: OpRewritePattern<vector::TransferReadOp>(context, benefit),
436+
: MaskableOpRewritePattern<vector::TransferReadOp>(context, benefit),
437437
maxTransferRank(maxRank) {}
438438

439-
LogicalResult matchAndRewrite(vector::TransferReadOp read,
440-
PatternRewriter &rewriter) const override {
439+
FailureOr<mlir::Value>
440+
matchAndRewriteMaskableOp(vector::TransferReadOp read,
441+
MaskingOpInterface maskOp,
442+
PatternRewriter &rewriter) const override {
441443
if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank) {
442444
return rewriter.notifyMatchFailure(
443445
read, "vector type is greater than max transfer rank");
444446
}
445447

448+
if (maskOp)
449+
return rewriter.notifyMatchFailure(read, "Masked case not supported");
446450
SmallVector<unsigned> broadcastedDims;
447451
// Permutations are handled by VectorToSCF or
448452
// populateVectorTransferPermutationMapLoweringPatterns.
@@ -485,7 +489,7 @@ struct TransferReadToVectorLoadLowering
485489
return rewriter.notifyMatchFailure(read, "out-of-bounds needs mask");
486490

487491
// Create vector load op.
488-
Operation *loadOp;
492+
Operation *res;
489493
if (read.getMask()) {
490494
if (read.getVectorType().getRank() != 1)
491495
// vector.maskedload operates on 1-D vectors.
@@ -495,24 +499,20 @@ struct TransferReadToVectorLoadLowering
495499

496500
Value fill = rewriter.create<vector::SplatOp>(
497501
read.getLoc(), unbroadcastedVectorType, read.getPadding());
498-
loadOp = rewriter.create<vector::MaskedLoadOp>(
502+
res = rewriter.create<vector::MaskedLoadOp>(
499503
read.getLoc(), unbroadcastedVectorType, read.getSource(),
500504
read.getIndices(), read.getMask(), fill);
501505
} else {
502-
loadOp = rewriter.create<vector::LoadOp>(
506+
res = rewriter.create<vector::LoadOp>(
503507
read.getLoc(), unbroadcastedVectorType, read.getSource(),
504508
read.getIndices());
505509
}
506510

507511
// Insert a broadcasting op if required.
508-
if (!broadcastedDims.empty()) {
509-
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
510-
read, read.getVectorType(), loadOp->getResult(0));
511-
} else {
512-
rewriter.replaceOp(read, loadOp->getResult(0));
513-
}
514-
515-
return success();
512+
if (!broadcastedDims.empty())
513+
res = rewriter.create<vector::BroadcastOp>(
514+
read.getLoc(), read.getVectorType(), res->getResult(0));
515+
return res->getResult(0);
516516
}
517517

518518
std::optional<unsigned> maxTransferRank;
@@ -581,19 +581,23 @@ struct VectorStoreToMemrefStoreLowering
581581
/// - The permutation map is the minor identity map (neither permutation nor
582582
/// broadcasting is allowed).
583583
struct TransferWriteToVectorStoreLowering
584-
: public OpRewritePattern<vector::TransferWriteOp> {
584+
: public MaskableOpRewritePattern<vector::TransferWriteOp> {
585585
TransferWriteToVectorStoreLowering(MLIRContext *context,
586586
std::optional<unsigned> maxRank,
587587
PatternBenefit benefit = 1)
588-
: OpRewritePattern<vector::TransferWriteOp>(context, benefit),
588+
: MaskableOpRewritePattern<vector::TransferWriteOp>(context, benefit),
589589
maxTransferRank(maxRank) {}
590590

591-
LogicalResult matchAndRewrite(vector::TransferWriteOp write,
592-
PatternRewriter &rewriter) const override {
591+
FailureOr<mlir::Value>
592+
matchAndRewriteMaskableOp(vector::TransferWriteOp write,
593+
MaskingOpInterface maskOp,
594+
PatternRewriter &rewriter) const override {
593595
if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank) {
594596
return rewriter.notifyMatchFailure(
595597
write, "vector type is greater than max transfer rank");
596598
}
599+
if (maskOp)
600+
return rewriter.notifyMatchFailure(write, "Masked case not supported");
597601

598602
// Permutations are handled by VectorToSCF or
599603
// populateVectorTransferPermutationMapLoweringPatterns.
@@ -645,14 +649,16 @@ struct TransferWriteToVectorStoreLowering
645649
<< write;
646650
});
647651

648-
rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
649-
write, write.getSource(), write.getIndices(), write.getMask(),
650-
write.getVector());
652+
rewriter.create<vector::MaskedStoreOp>(
653+
write.getLoc(), write.getSource(), write.getIndices(),
654+
write.getMask(), write.getVector());
651655
} else {
652-
rewriter.replaceOpWithNewOp<vector::StoreOp>(
653-
write, write.getVector(), write.getSource(), write.getIndices());
656+
rewriter.create<vector::StoreOp>(write.getLoc(), write.getVector(),
657+
write.getSource(), write.getIndices());
654658
}
655-
return success();
659+
// There's no return value for StoreOps. Use Value() to signal success to
660+
// matchAndRewrite.
661+
return Value();
656662
}
657663

658664
std::optional<unsigned> maxTransferRank;

mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,23 @@ func.func @transfer_to_load(%mem : memref<8x8xf32>, %i : index) -> vector<4xf32>
5151
return %res : vector<4xf32>
5252
}
5353

54+
// Masked transfer_read/write inside are NOT lowered to vector.load/store
55+
// CHECK-LABEL: func @masked_transfer_to_load(
56+
// CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>,
57+
// CHECK-SAME: %[[IDX:.*]]: index,
58+
// CHECK-SAME: %[[MASK:.*]]: vector<4xi1>) -> memref<8x8xf32>
59+
// CHECK-NOT: vector.load
60+
// CHECK-NOT: vector.store
61+
// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %arg0[%[[IDX]], %[[IDX]]]{{.*}} : memref<8x8xf32>, vector<4xf32> } : vector<4xi1> -> vector<4xf32>
62+
// CHECK: vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[MEM]][%[[IDX]], %[[IDX]]]{{.*}} : vector<4xf32>, memref<8x8xf32> } : vector<4xi1>
63+
64+
func.func @masked_transfer_to_load(%mem : memref<8x8xf32>, %i : index, %mask : vector<4xi1>) -> memref<8x8xf32> {
65+
%cf0 = arith.constant 0.0 : f32
66+
%read = vector.mask %mask {vector.transfer_read %mem[%i, %i], %cf0 {in_bounds = [true]} : memref<8x8xf32>, vector<4xf32>} : vector<4xi1> -> vector<4xf32>
67+
vector.mask %mask {vector.transfer_write %read, %mem[%i, %i] {in_bounds = [true]} : vector<4xf32>, memref<8x8xf32> } : vector<4xi1>
68+
return %mem : memref<8x8xf32>
69+
}
70+
5471
// n-D results are also supported.
5572
// CHECK-LABEL: func @transfer_2D(
5673
// CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>,

0 commit comments

Comments
 (0)