@@ -429,20 +429,24 @@ namespace {
429
429
// / result type.
430
430
// / - The permutation map doesn't perform permutation (broadcasting is allowed).
431
431
struct TransferReadToVectorLoadLowering
432
- : public OpRewritePattern <vector::TransferReadOp> {
432
+ : public MaskableOpRewritePattern <vector::TransferReadOp> {
433
433
TransferReadToVectorLoadLowering (MLIRContext *context,
434
434
std::optional<unsigned > maxRank,
435
435
PatternBenefit benefit = 1 )
436
- : OpRewritePattern <vector::TransferReadOp>(context, benefit),
436
+ : MaskableOpRewritePattern <vector::TransferReadOp>(context, benefit),
437
437
maxTransferRank (maxRank) {}
438
438
439
- LogicalResult matchAndRewrite (vector::TransferReadOp read,
440
- PatternRewriter &rewriter) const override {
439
+ FailureOr<mlir::Value>
440
+ matchAndRewriteMaskableOp (vector::TransferReadOp read,
441
+ MaskingOpInterface maskOp,
442
+ PatternRewriter &rewriter) const override {
441
443
if (maxTransferRank && read.getVectorType ().getRank () > *maxTransferRank) {
442
444
return rewriter.notifyMatchFailure (
443
445
read, " vector type is greater than max transfer rank" );
444
446
}
445
447
448
+ if (maskOp)
449
+ return rewriter.notifyMatchFailure (read, " Masked case not supported" );
446
450
SmallVector<unsigned > broadcastedDims;
447
451
// Permutations are handled by VectorToSCF or
448
452
// populateVectorTransferPermutationMapLoweringPatterns.
@@ -485,7 +489,7 @@ struct TransferReadToVectorLoadLowering
485
489
return rewriter.notifyMatchFailure (read, " out-of-bounds needs mask" );
486
490
487
491
// Create vector load op.
488
- Operation *loadOp ;
492
+ Operation *res ;
489
493
if (read.getMask ()) {
490
494
if (read.getVectorType ().getRank () != 1 )
491
495
// vector.maskedload operates on 1-D vectors.
@@ -495,24 +499,20 @@ struct TransferReadToVectorLoadLowering
495
499
496
500
Value fill = rewriter.create <vector::SplatOp>(
497
501
read.getLoc (), unbroadcastedVectorType, read.getPadding ());
498
- loadOp = rewriter.create <vector::MaskedLoadOp>(
502
+ res = rewriter.create <vector::MaskedLoadOp>(
499
503
read.getLoc (), unbroadcastedVectorType, read.getSource (),
500
504
read.getIndices (), read.getMask (), fill);
501
505
} else {
502
- loadOp = rewriter.create <vector::LoadOp>(
506
+ res = rewriter.create <vector::LoadOp>(
503
507
read.getLoc (), unbroadcastedVectorType, read.getSource (),
504
508
read.getIndices ());
505
509
}
506
510
507
511
// Insert a broadcasting op if required.
508
- if (!broadcastedDims.empty ()) {
509
- rewriter.replaceOpWithNewOp <vector::BroadcastOp>(
510
- read, read.getVectorType (), loadOp->getResult (0 ));
511
- } else {
512
- rewriter.replaceOp (read, loadOp->getResult (0 ));
513
- }
514
-
515
- return success ();
512
+ if (!broadcastedDims.empty ())
513
+ res = rewriter.create <vector::BroadcastOp>(
514
+ read.getLoc (), read.getVectorType (), res->getResult (0 ));
515
+ return res->getResult (0 );
516
516
}
517
517
518
518
std::optional<unsigned > maxTransferRank;
@@ -581,19 +581,23 @@ struct VectorStoreToMemrefStoreLowering
581
581
// / - The permutation map is the minor identity map (neither permutation nor
582
582
// / broadcasting is allowed).
583
583
struct TransferWriteToVectorStoreLowering
584
- : public OpRewritePattern <vector::TransferWriteOp> {
584
+ : public MaskableOpRewritePattern <vector::TransferWriteOp> {
585
585
TransferWriteToVectorStoreLowering (MLIRContext *context,
586
586
std::optional<unsigned > maxRank,
587
587
PatternBenefit benefit = 1 )
588
- : OpRewritePattern <vector::TransferWriteOp>(context, benefit),
588
+ : MaskableOpRewritePattern <vector::TransferWriteOp>(context, benefit),
589
589
maxTransferRank (maxRank) {}
590
590
591
- LogicalResult matchAndRewrite (vector::TransferWriteOp write,
592
- PatternRewriter &rewriter) const override {
591
+ FailureOr<mlir::Value>
592
+ matchAndRewriteMaskableOp (vector::TransferWriteOp write,
593
+ MaskingOpInterface maskOp,
594
+ PatternRewriter &rewriter) const override {
593
595
if (maxTransferRank && write.getVectorType ().getRank () > *maxTransferRank) {
594
596
return rewriter.notifyMatchFailure (
595
597
write, " vector type is greater than max transfer rank" );
596
598
}
599
+ if (maskOp)
600
+ return rewriter.notifyMatchFailure (write, " Masked case not supported" );
597
601
598
602
// Permutations are handled by VectorToSCF or
599
603
// populateVectorTransferPermutationMapLoweringPatterns.
@@ -645,14 +649,16 @@ struct TransferWriteToVectorStoreLowering
645
649
<< write;
646
650
});
647
651
648
- rewriter.replaceOpWithNewOp <vector::MaskedStoreOp>(
649
- write, write. getSource (), write.getIndices (), write.getMask (),
650
- write.getVector ());
652
+ rewriter.create <vector::MaskedStoreOp>(
653
+ write. getLoc (), write.getSource (), write.getIndices (),
654
+ write.getMask (), write. getVector ());
651
655
} else {
652
- rewriter.replaceOpWithNewOp <vector::StoreOp>(
653
- write, write. getVector (), write.getSource (), write.getIndices ());
656
+ rewriter.create <vector::StoreOp>(write. getLoc (), write. getVector (),
657
+ write.getSource (), write.getIndices ());
654
658
}
655
- return success ();
659
+ // There's no return value for StoreOps. Use Value() to signal success to
660
+ // matchAndRewrite.
661
+ return Value ();
656
662
}
657
663
658
664
std::optional<unsigned > maxTransferRank;
0 commit comments