Skip to content

Commit 1c4b04c

Browse files
authored
[mlir] Fix crash in InsertOpConstantFolder when vector.insert operand is from a llvm.mlir.constant op (#88314)
In cases where llvm.mlir.constant has an attribute with a different type than the returned type, the folder use to create an incorrect DenseElementsAttr and crash. Resolves #74236
1 parent 9ade4e2 commit 1c4b04c

File tree

2 files changed

+35
-4
lines changed

2 files changed

+35
-4
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2896,10 +2896,17 @@ class InsertOpConstantFolder final : public OpRewritePattern<InsertOp> {
28962896
linearize(completePositions, computeStrides(destTy.getShape()));
28972897

28982898
SmallVector<Attribute> insertedValues;
2899-
if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(sourceCst))
2900-
llvm::append_range(insertedValues, denseSource.getValues<Attribute>());
2901-
else
2902-
insertedValues.push_back(sourceCst);
2899+
Type destEltType = destTy.getElementType();
2900+
2901+
// The `convertIntegerAttr` method specifically handles the case
2902+
// for `llvm.mlir.constant` which can hold an attribute with a
2903+
// different type than the return type.
2904+
if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(sourceCst)) {
2905+
for (auto value : denseSource.getValues<Attribute>())
2906+
insertedValues.push_back(convertIntegerAttr(value, destEltType));
2907+
} else {
2908+
insertedValues.push_back(convertIntegerAttr(sourceCst, destEltType));
2909+
}
29032910

29042911
auto allValues = llvm::to_vector(denseDest.getValues<Attribute>());
29052912
copy(insertedValues, allValues.begin() + insertBeginPosition);
@@ -2908,6 +2915,17 @@ class InsertOpConstantFolder final : public OpRewritePattern<InsertOp> {
29082915
rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
29092916
return success();
29102917
}
2918+
2919+
private:
2920+
/// Converts the expected type to an IntegerAttr if there's
2921+
/// a mismatch.
2922+
Attribute convertIntegerAttr(Attribute attr, Type expectedType) const {
2923+
if (auto intAttr = attr.dyn_cast<IntegerAttr>()) {
2924+
if (intAttr.getType() != expectedType)
2925+
return IntegerAttr::get(expectedType, intAttr.getInt());
2926+
}
2927+
return attr;
2928+
}
29112929
};
29122930

29132931
} // namespace

mlir/test/Dialect/LLVMIR/constant-folding.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,3 +169,16 @@ llvm.func @null_pointer_select(%cond: i1) -> !llvm.ptr {
169169
// CHECK-NEXT: llvm.return %[[NULLPTR]]
170170
llvm.return %result : !llvm.ptr
171171
}
172+
173+
// -----
174+
175+
llvm.func @malloc(i64) -> !llvm.ptr
176+
177+
// CHECK-LABEL: func.func @insert_op
178+
func.func @insert_op(%arg0: index, %arg1: memref<13x13xi64>, %arg2: index) {
179+
%cst_7 = arith.constant dense<1526248407> : vector<1xi64>
180+
%1 = llvm.mlir.constant(1 : index) : i64
181+
%101 = vector.insert %1, %cst_7 [0] : i64 into vector<1xi64>
182+
vector.print %101 : vector<1xi64>
183+
return
184+
}

0 commit comments

Comments
 (0)