@@ -5660,49 +5660,102 @@ LogicalResult CreateMaskOp::verify() {
5660
5660
5661
5661
namespace {
5662
5662
5663
- // Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
5663
+ // / Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
5664
+ // /
5665
+ // / Ex 1:
5666
+ // / %c2 = arith.constant 2 : index
5667
+ // / %c3 = arith.constant 3 : index
5668
+ // / %0 = vector.create_mask %c3, %c2 : vector<4x3xi1>
5669
+ // / Becomes:
5670
+ // / vector.constant_mask [3, 2] : vector<4x3xi1>
5671
+ // /
5672
+ // / Ex 2:
5673
+ // / %c_neg_1 = arith.constant -1 : index
5674
+ // / %0 = vector.create_mask %c_neg_1 : vector<[8]xi1>
5675
+ // / becomes:
5676
+ // / vector.constant_mask [0] : vector<[8]xi1>
5677
+ // /
5678
+ // / Ex 3:
5679
+ // / %c8 = arith.constant 8 : index
5680
+ // / %c16 = arith.constant 16 : index
5681
+ // / %0 = vector.vscale
5682
+ // / %1 = arith.muli %0, %c16 : index
5683
+ // / %10 = vector.create_mask %c8, %1 : vector<8x[16]xi1>
5684
+ // / becomes:
5685
+ // / %0 = vector.constant_mask [8, 16] : vector<8x[16]xi1>
5664
5686
class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
5665
5687
public:
5666
5688
using OpRewritePattern::OpRewritePattern;
5667
5689
5668
5690
LogicalResult matchAndRewrite (CreateMaskOp createMaskOp,
5669
5691
PatternRewriter &rewriter) const override {
5670
- // Return if any of 'createMaskOp' operands are not defined by a constant.
5671
- auto isNotDefByConstant = [](Value operand) {
5672
- return !getConstantIntValue (operand).has_value ();
5673
- };
5674
- if (llvm::any_of (createMaskOp.getOperands (), isNotDefByConstant))
5675
- return failure ();
5692
+ VectorType retTy = createMaskOp.getResult ().getType ();
5693
+ bool isScalable = retTy.isScalable ();
5694
+
5695
+ // Check every mask operand
5696
+ for (auto [opIdx, operand] : llvm::enumerate (createMaskOp.getOperands ())) {
5697
+ if (auto cst = getConstantIntValue (operand)) {
5698
+ // Most basic case - this operand is a constant value. Note that for
5699
+ // scalable dimensions, CreateMaskOp can be folded only if the
5700
+ // corresponding operand is negative or zero.
5701
+ if (retTy.getScalableDims ()[opIdx] && *cst > 0 )
5702
+ return failure ();
5676
5703
5677
- // CreateMaskOp for scalable vectors can be folded only if all dimensions
5678
- // are negative or zero.
5679
- if (auto vType = llvm::dyn_cast<VectorType>(createMaskOp.getType ())) {
5680
- if (vType.isScalable ())
5681
- for (auto opDim : createMaskOp.getOperands ()) {
5682
- APInt intVal;
5683
- if (matchPattern (opDim, m_ConstantInt (&intVal)) &&
5684
- intVal.isStrictlyPositive ())
5685
- return failure ();
5686
- }
5704
+ continue ;
5705
+ }
5706
+
5707
+ // Non-constant operands are not allowed for non-scalable vectors.
5708
+ if (!isScalable)
5709
+ return failure ();
5710
+
5711
+ // For scalable vectors, "arith.muli %vscale, %dimSize" means an "all
5712
+ // true" mask, so can also be treated as constant.
5713
+ auto mul = operand.getDefiningOp <arith::MulIOp>();
5714
+ if (!mul)
5715
+ return failure ();
5716
+ auto mulLHS = mul.getRhs ();
5717
+ auto mulRHS = mul.getLhs ();
5718
+ bool isOneOpVscale =
5719
+ (isa<vector::VectorScaleOp>(mulLHS.getDefiningOp ()) ||
5720
+ isa<vector::VectorScaleOp>(mulRHS.getDefiningOp ()));
5721
+
5722
+ auto isConstantValMatchingDim =
5723
+ [=, dim = retTy.getShape ()[opIdx]](Value operand) {
5724
+ auto constantVal = getConstantIntValue (operand);
5725
+ return (constantVal.has_value () && constantVal.value () == dim);
5726
+ };
5727
+
5728
+ bool isOneOpConstantMatchingDim =
5729
+ isConstantValMatchingDim (mulLHS) || isConstantValMatchingDim (mulRHS);
5730
+
5731
+ if (!isOneOpVscale || !isOneOpConstantMatchingDim)
5732
+ return failure ();
5687
5733
}
5688
5734
5689
5735
// Gather constant mask dimension sizes.
5690
5736
SmallVector<int64_t , 4 > maskDimSizes;
5691
5737
maskDimSizes.reserve (createMaskOp->getNumOperands ());
5692
5738
for (auto [operand, maxDimSize] : llvm::zip_equal (
5693
5739
createMaskOp.getOperands (), createMaskOp.getType ().getShape ())) {
5694
- int64_t dimSize = getConstantIntValue (operand).value ();
5695
- dimSize = std::min (dimSize, maxDimSize);
5740
+ std::optional dimSize = getConstantIntValue (operand);
5741
+ if (!dimSize) {
5742
+ // Although not a constant, it is safe to assume that `operand` is
5743
+ // "vscale * maxDimSize".
5744
+ maskDimSizes.push_back (maxDimSize);
5745
+ continue ;
5746
+ }
5747
+ int64_t dimSizeVal = std::min (dimSize.value (), maxDimSize);
5696
5748
// If one of dim sizes is zero, set all dims to zero.
5697
5749
if (dimSize <= 0 ) {
5698
5750
maskDimSizes.assign (createMaskOp.getType ().getRank (), 0 );
5699
5751
break ;
5700
5752
}
5701
- maskDimSizes.push_back (dimSize );
5753
+ maskDimSizes.push_back (dimSizeVal );
5702
5754
}
5755
+
5703
5756
// Replace 'createMaskOp' with ConstantMaskOp.
5704
5757
rewriter.replaceOpWithNewOp <ConstantMaskOp>(
5705
- createMaskOp, createMaskOp. getResult (). getType () ,
5758
+ createMaskOp, retTy ,
5706
5759
vector::getVectorSubscriptAttr (rewriter, maskDimSizes));
5707
5760
return success ();
5708
5761
}
0 commit comments