@@ -2896,10 +2896,17 @@ class InsertOpConstantFolder final : public OpRewritePattern<InsertOp> {
2896
2896
linearize (completePositions, computeStrides (destTy.getShape ()));
2897
2897
2898
2898
SmallVector<Attribute> insertedValues;
2899
- if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(sourceCst))
2900
- llvm::append_range (insertedValues, denseSource.getValues <Attribute>());
2901
- else
2902
- insertedValues.push_back (sourceCst);
2899
+ Type destEltType = destTy.getElementType ();
2900
+
2901
+ // The `convertIntegerAttr` method specifically handles the case
2902
+ // for `llvm.mlir.constant` which can hold an attribute with a
2903
+ // different type than the return type.
2904
+ if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(sourceCst)) {
2905
+ for (auto value : denseSource.getValues <Attribute>())
2906
+ insertedValues.push_back (convertIntegerAttr (value, destEltType));
2907
+ } else {
2908
+ insertedValues.push_back (convertIntegerAttr (sourceCst, destEltType));
2909
+ }
2903
2910
2904
2911
auto allValues = llvm::to_vector (denseDest.getValues <Attribute>());
2905
2912
copy (insertedValues, allValues.begin () + insertBeginPosition);
@@ -2908,6 +2915,17 @@ class InsertOpConstantFolder final : public OpRewritePattern<InsertOp> {
2908
2915
rewriter.replaceOpWithNewOp <arith::ConstantOp>(op, newAttr);
2909
2916
return success ();
2910
2917
}
2918
+
2919
+ private:
2920
+ // / Converts the expected type to an IntegerAttr if there's
2921
+ // / a mismatch.
2922
+ Attribute convertIntegerAttr (Attribute attr, Type expectedType) const {
2923
+ if (auto intAttr = attr.dyn_cast <IntegerAttr>()) {
2924
+ if (intAttr.getType () != expectedType)
2925
+ return IntegerAttr::get (expectedType, intAttr.getInt ());
2926
+ }
2927
+ return attr;
2928
+ }
2911
2929
};
2912
2930
2913
2931
} // namespace
0 commit comments