Skip to content

[mlir][Vector] Move vector.insert canonicalizers for DenseElementsAttr to folders #128040

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 54 additions & 65 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3019,94 +3019,78 @@ class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
}
};

// Pattern to rewrite an InsertOp(ConstantOp into ConstantOp) -> ConstantOp.
class InsertOpConstantFolder final : public OpRewritePattern<InsertOp> {
public:
using OpRewritePattern::OpRewritePattern;

// Do not create constants with more than `vectorSizeFoldThreshold` elements,
// unless the destination vector constant has a single use.
static constexpr int64_t vectorSizeFoldThreshold = 256;

LogicalResult matchAndRewrite(InsertOp op,
PatternRewriter &rewriter) const override {
// TODO: Canonicalization for dynamic position not implemented yet.
if (op.hasDynamicPosition())
return failure();
} // namespace

// Return if 'InsertOp' operand is not defined by a compatible vector
// ConstantOp.
TypedValue<VectorType> destVector = op.getDest();
Attribute vectorDestCst;
if (!matchPattern(destVector, m_Constant(&vectorDestCst)))
return failure();
auto denseDest = llvm::dyn_cast<DenseElementsAttr>(vectorDestCst);
if (!denseDest)
return failure();
static Attribute
foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr,
Attribute dstAttr,
int64_t maxVectorSizeFoldThreshold) {
if (insertOp.hasDynamicPosition())
return {};

VectorType destTy = destVector.getType();
if (destTy.isScalable())
return failure();
auto denseDst = llvm::dyn_cast_if_present<DenseElementsAttr>(dstAttr);
if (!denseDst)
return {};

// Make sure we do not create too many large constants.
if (destTy.getNumElements() > vectorSizeFoldThreshold &&
!destVector.hasOneUse())
return failure();
if (!srcAttr) {
return {};
}

Value sourceValue = op.getSource();
Attribute sourceCst;
if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
return failure();
VectorType destTy = insertOp.getDestVectorType();
if (destTy.isScalable())
return {};

// Calculate the linearized position of the continuous chunk of elements to
// insert.
llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
copy(op.getStaticPosition(), completePositions.begin());
int64_t insertBeginPosition =
linearize(completePositions, computeStrides(destTy.getShape()));

SmallVector<Attribute> insertedValues;
Type destEltType = destTy.getElementType();

// The `convertIntegerAttr` method specifically handles the case
// for `llvm.mlir.constant` which can hold an attribute with a
// different type than the return type.
if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(sourceCst)) {
for (auto value : denseSource.getValues<Attribute>())
insertedValues.push_back(convertIntegerAttr(value, destEltType));
} else {
insertedValues.push_back(convertIntegerAttr(sourceCst, destEltType));
}
// Make sure we do not create too many large constants.
if (destTy.getNumElements() > maxVectorSizeFoldThreshold &&
!insertOp->hasOneUse())
return {};

auto allValues = llvm::to_vector(denseDest.getValues<Attribute>());
copy(insertedValues, allValues.begin() + insertBeginPosition);
auto newAttr = DenseElementsAttr::get(destTy, allValues);
// Calculate the linearized position of the continuous chunk of elements to
// insert.
llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
copy(insertOp.getStaticPosition(), completePositions.begin());
int64_t insertBeginPosition =
linearize(completePositions, computeStrides(destTy.getShape()));

rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
return success();
}
SmallVector<Attribute> insertedValues;
Type destEltType = destTy.getElementType();

private:
/// Converts the expected type to an IntegerAttr if there's
/// a mismatch.
Attribute convertIntegerAttr(Attribute attr, Type expectedType) const {
auto convertIntegerAttr = [](Attribute attr, Type expectedType) -> Attribute {
if (auto intAttr = mlir::dyn_cast<IntegerAttr>(attr)) {
if (intAttr.getType() != expectedType)
return IntegerAttr::get(expectedType, intAttr.getInt());
}
return attr;
};

// The `convertIntegerAttr` method specifically handles the case
// for `llvm.mlir.constant` which can hold an attribute with a
// different type than the return type.
if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(srcAttr)) {
for (auto value : denseSource.getValues<Attribute>())
insertedValues.push_back(convertIntegerAttr(value, destEltType));
} else {
insertedValues.push_back(convertIntegerAttr(srcAttr, destEltType));
}
};

} // namespace
auto allValues = llvm::to_vector(denseDst.getValues<Attribute>());
copy(insertedValues, allValues.begin() + insertBeginPosition);
auto newAttr = DenseElementsAttr::get(destTy, allValues);

return newAttr;
}

void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
InsertOpConstantFolder>(context);
results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat>(context);
}

OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
// Do not create constants with more than `vectorSizeFoldThreshold` elements,
// unless this insert op has a single use.
constexpr int64_t vectorSizeFoldThreshold = 256;
// Fold "vector.insert %v, %dest [] : vector<2x2xf32> from vector<2x2xf32>" to
// %v. Note: Do not fold "vector.insert %v, %dest [] : f32 into vector<f32>"
// (type mismatch).
Expand All @@ -3118,6 +3102,11 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
if (auto res = foldPoisonIndexInsertExtractOp(
getContext(), adaptor.getStaticPosition(), kPoisonIndex))
return res;
if (auto res = foldDenseElementsAttrDestInsertOp(*this, adaptor.getSource(),
adaptor.getDest(),
vectorSizeFoldThreshold)) {
return res;
}

return {};
}
Expand Down
10 changes: 3 additions & 7 deletions mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1517,13 +1517,9 @@ func.func @constant_mask_2d() -> vector<4x4xi1> {
}

// CHECK-LABEL: func @constant_mask_2d
// CHECK: %[[VAL_0:.*]] = arith.constant dense<[true, true, false, false]> : vector<4xi1>
// CHECK: %[[VAL_1:.*]] = arith.constant dense<false> : vector<4x4xi1>
// CHECK: %[[VAL_2:.*]] = builtin.unrealized_conversion_cast %[[VAL_1]] : vector<4x4xi1> to !llvm.array<4 x vector<4xi1>>
// CHECK: %[[VAL_3:.*]] = llvm.insertvalue %[[VAL_0]], %[[VAL_2]][0] : !llvm.array<4 x vector<4xi1>>
// CHECK: %[[VAL_4:.*]] = llvm.insertvalue %[[VAL_0]], %[[VAL_3]][1] : !llvm.array<4 x vector<4xi1>>
// CHECK: %[[VAL_5:.*]] = builtin.unrealized_conversion_cast %[[VAL_4]] : !llvm.array<4 x vector<4xi1>> to vector<4x4xi1>
// CHECK: return %[[VAL_5]] : vector<4x4xi1>
// CHECK: %[[VAL_0:.*]] = arith.constant
// CHECK-SAME{LITERAL}: dense<[[true, true, false, false], [true, true, false, false], [false, false, false, false], [false, false, false, false]]> : vector<4x4xi1>
// CHECK: return %[[VAL_0]] : vector<4x4xi1>

// -----

Expand Down
17 changes: 6 additions & 11 deletions mlir/test/Dialect/Vector/vector-mask-lowering-transforms.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,19 @@ func.func @genbool_1d() -> vector<8xi1> {
}

// CHECK-LABEL: func @genbool_2d
// CHECK: %[[C1:.*]] = arith.constant dense<[true, true, false, false]> : vector<4xi1>
// CHECK: %[[C2:.*]] = arith.constant dense<false> : vector<4x4xi1>
// CHECK: %[[T0:.*]] = vector.insert %[[C1]], %[[C2]] [0] : vector<4xi1> into vector<4x4xi1>
// CHECK: %[[T1:.*]] = vector.insert %[[C1]], %[[T0]] [1] : vector<4xi1> into vector<4x4xi1>
// CHECK: return %[[T1]] : vector<4x4xi1>
// CHECK: %[[C0:.*]] = arith.constant
// CHECK-SAME{LITERAL}: dense<[[true, true, false, false], [true, true, false, false], [false, false, false, false], [false, false, false, false]]> : vector<4x4xi1>
// CHECK: return %[[C0]] : vector<4x4xi1>

func.func @genbool_2d() -> vector<4x4xi1> {
%v = vector.constant_mask [2, 2] : vector<4x4xi1>
return %v: vector<4x4xi1>
}

// CHECK-LABEL: func @genbool_3d
// CHECK-DAG: %[[C1:.*]] = arith.constant dense<[true, true, true, false]> : vector<4xi1>
// CHECK-DAG: %[[C2:.*]] = arith.constant dense<false> : vector<3x4xi1>
// CHECK-DAG: %[[C3:.*]] = arith.constant dense<false> : vector<2x3x4xi1>
// CHECK: %[[T0:.*]] = vector.insert %[[C1]], %[[C2]] [0] : vector<4xi1> into vector<3x4xi1>
// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C3]] [0] : vector<3x4xi1> into vector<2x3x4xi1>
// CHECK: return %[[T1]] : vector<2x3x4xi1>
// CHECK: %[[C0:.*]] = arith.constant
// CHECK-SAME{LITERAL}: dense<[[[true, true, true, false], [false, false, false, false], [false, false, false, false]], [[false, false, false, false], [false, false, false, false], [false, false, false, false]]]> : vector<2x3x4xi1>
// CHECK: return %[[C0]] : vector<2x3x4xi1>

func.func @genbool_3d() -> vector<2x3x4xi1> {
%v = vector.constant_mask [1, 1, 3] : vector<2x3x4xi1>
Expand Down