Skip to content

Commit 354adb4

Browse files
authored
[mlir][vector] Extend CreateMaskFolder (#75842)
Extends `CreateMaskFolder` pattern so that the following: ```mlir %c8 = arith.constant 8 : index %c16 = arith.constant 16 : index %0 = vector.vscale %1 = arith.muli %0, %c16 : index %10 = vector.create_mask %c8, %1 : vector<8x[16]xi1> ``` is folded as: ```mlir %0 = vector.constant_mask [8, 16] : vector<8x[16]xi1> ```
1 parent 83f8cae commit 354adb4

File tree

2 files changed

+87
-21
lines changed

2 files changed

+87
-21
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 74 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5660,49 +5660,102 @@ LogicalResult CreateMaskOp::verify() {
56605660

56615661
namespace {
56625662

5663-
// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
5663+
/// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
5664+
///
5665+
/// Ex 1:
5666+
/// %c2 = arith.constant 2 : index
5667+
/// %c3 = arith.constant 3 : index
5668+
/// %0 = vector.create_mask %c3, %c2 : vector<4x3xi1>
5669+
/// Becomes:
5670+
/// vector.constant_mask [3, 2] : vector<4x3xi1>
5671+
///
5672+
/// Ex 2:
5673+
/// %c_neg_1 = arith.constant -1 : index
5674+
/// %0 = vector.create_mask %c_neg_1 : vector<[8]xi1>
5675+
/// becomes:
5676+
/// vector.constant_mask [0] : vector<[8]xi1>
5677+
///
5678+
/// Ex 3:
5679+
/// %c8 = arith.constant 8 : index
5680+
/// %c16 = arith.constant 16 : index
5681+
/// %0 = vector.vscale
5682+
/// %1 = arith.muli %0, %c16 : index
5683+
/// %10 = vector.create_mask %c8, %1 : vector<8x[16]xi1>
5684+
/// becomes:
5685+
/// %0 = vector.constant_mask [8, 16] : vector<8x[16]xi1>
56645686
class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
56655687
public:
56665688
using OpRewritePattern::OpRewritePattern;
56675689

56685690
LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
56695691
PatternRewriter &rewriter) const override {
5670-
// Return if any of 'createMaskOp' operands are not defined by a constant.
5671-
auto isNotDefByConstant = [](Value operand) {
5672-
return !getConstantIntValue(operand).has_value();
5673-
};
5674-
if (llvm::any_of(createMaskOp.getOperands(), isNotDefByConstant))
5675-
return failure();
5692+
VectorType retTy = createMaskOp.getResult().getType();
5693+
bool isScalable = retTy.isScalable();
5694+
5695+
// Check every mask operand
5696+
for (auto [opIdx, operand] : llvm::enumerate(createMaskOp.getOperands())) {
5697+
if (auto cst = getConstantIntValue(operand)) {
5698+
// Most basic case - this operand is a constant value. Note that for
5699+
// scalable dimensions, CreateMaskOp can be folded only if the
5700+
// corresponding operand is negative or zero.
5701+
if (retTy.getScalableDims()[opIdx] && *cst > 0)
5702+
return failure();
56765703

5677-
// CreateMaskOp for scalable vectors can be folded only if all dimensions
5678-
// are negative or zero.
5679-
if (auto vType = llvm::dyn_cast<VectorType>(createMaskOp.getType())) {
5680-
if (vType.isScalable())
5681-
for (auto opDim : createMaskOp.getOperands()) {
5682-
APInt intVal;
5683-
if (matchPattern(opDim, m_ConstantInt(&intVal)) &&
5684-
intVal.isStrictlyPositive())
5685-
return failure();
5686-
}
5704+
continue;
5705+
}
5706+
5707+
// Non-constant operands are not allowed for non-scalable vectors.
5708+
if (!isScalable)
5709+
return failure();
5710+
5711+
// For scalable vectors, "arith.muli %vscale, %dimSize" means an "all
5712+
// true" mask, so can also be treated as constant.
5713+
auto mul = operand.getDefiningOp<arith::MulIOp>();
5714+
if (!mul)
5715+
return failure();
5716+
auto mulLHS = mul.getRhs();
5717+
auto mulRHS = mul.getLhs();
5718+
bool isOneOpVscale =
5719+
(isa<vector::VectorScaleOp>(mulLHS.getDefiningOp()) ||
5720+
isa<vector::VectorScaleOp>(mulRHS.getDefiningOp()));
5721+
5722+
auto isConstantValMatchingDim =
5723+
[=, dim = retTy.getShape()[opIdx]](Value operand) {
5724+
auto constantVal = getConstantIntValue(operand);
5725+
return (constantVal.has_value() && constantVal.value() == dim);
5726+
};
5727+
5728+
bool isOneOpConstantMatchingDim =
5729+
isConstantValMatchingDim(mulLHS) || isConstantValMatchingDim(mulRHS);
5730+
5731+
if (!isOneOpVscale || !isOneOpConstantMatchingDim)
5732+
return failure();
56875733
}
56885734

56895735
// Gather constant mask dimension sizes.
56905736
SmallVector<int64_t, 4> maskDimSizes;
56915737
maskDimSizes.reserve(createMaskOp->getNumOperands());
56925738
for (auto [operand, maxDimSize] : llvm::zip_equal(
56935739
createMaskOp.getOperands(), createMaskOp.getType().getShape())) {
5694-
int64_t dimSize = getConstantIntValue(operand).value();
5695-
dimSize = std::min(dimSize, maxDimSize);
5740+
std::optional dimSize = getConstantIntValue(operand);
5741+
if (!dimSize) {
5742+
// Although not a constant, it is safe to assume that `operand` is
5743+
// "vscale * maxDimSize".
5744+
maskDimSizes.push_back(maxDimSize);
5745+
continue;
5746+
}
5747+
int64_t dimSizeVal = std::min(dimSize.value(), maxDimSize);
56965748
// If one of dim sizes is zero, set all dims to zero.
56975749
if (dimSize <= 0) {
56985750
maskDimSizes.assign(createMaskOp.getType().getRank(), 0);
56995751
break;
57005752
}
5701-
maskDimSizes.push_back(dimSize);
5753+
maskDimSizes.push_back(dimSizeVal);
57025754
}
5755+
57035756
// Replace 'createMaskOp' with ConstantMaskOp.
57045757
rewriter.replaceOpWithNewOp<ConstantMaskOp>(
5705-
createMaskOp, createMaskOp.getResult().getType(),
5758+
createMaskOp, retTy,
57065759
vector::getVectorSubscriptAttr(rewriter, maskDimSizes));
57075760
return success();
57085761
}

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,19 @@ func.func @create_vector_mask_to_constant_mask_truncation_zero() -> (vector<4x3x
5353

5454
// -----
5555

56+
// CHECK-LABEL: create_vector_mask_to_constant_mask_scalable_all_true
57+
func.func @create_vector_mask_to_constant_mask_scalable_all_true() -> (vector<8x[16]xi1>) {
58+
%c8 = arith.constant 8 : index
59+
%c16 = arith.constant 16 : index
60+
%0 = vector.vscale
61+
%1 = arith.muli %0, %c16 : index
62+
// CHECK: vector.constant_mask [8, 16] : vector<8x[16]xi1>
63+
%10 = vector.create_mask %c8, %1 : vector<8x[16]xi1>
64+
return %10 : vector<8x[16]xi1>
65+
}
66+
67+
// -----
68+
5669
// CHECK-LABEL: create_mask_transpose_to_transposed_create_mask
5770
// CHECK-SAME: %[[DIM0:.*]]: index, %[[DIM1:.*]]: index, %[[DIM2:.*]]: index
5871
func.func @create_mask_transpose_to_transposed_create_mask(

0 commit comments

Comments
 (0)