Skip to content

Commit 9db3421

Browse files
[mlir][vector] Fix crash in vector.insert canonicalization
The `InsertOpConstantFolder` assumed that whenever the destination can be folded to a constant attribute, that attribute must be a `DenseElementsAttr`. That is is not necessarily the case.
1 parent b546096 commit 9db3421

File tree

2 files changed

+14
-2
lines changed

2 files changed

+14
-2
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2851,6 +2851,9 @@ class InsertOpConstantFolder final : public OpRewritePattern<InsertOp> {
28512851
Attribute vectorDestCst;
28522852
if (!matchPattern(destVector, m_Constant(&vectorDestCst)))
28532853
return failure();
2854+
auto denseDest = llvm::dyn_cast<DenseElementsAttr>(vectorDestCst);
2855+
if (!denseDest)
2856+
return failure();
28542857

28552858
VectorType destTy = destVector.getType();
28562859
if (destTy.isScalable())
@@ -2861,8 +2864,6 @@ class InsertOpConstantFolder final : public OpRewritePattern<InsertOp> {
28612864
!destVector.hasOneUse())
28622865
return failure();
28632866

2864-
auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
2865-
28662867
Value sourceValue = op.getSource();
28672868
Attribute sourceCst;
28682869
if (!matchPattern(sourceValue, m_Constant(&sourceCst)))

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2729,3 +2729,14 @@ func.func @fold_vector_step_to_constant() -> vector<4xindex> {
27292729
%0 = vector.step : vector<4xindex>
27302730
return %0 : vector<4xindex>
27312731
}
2732+
2733+
// -----
2734+
2735+
// CHECK-LABEL: func @vector_insert_const_regression(
2736+
// CHECK: llvm.mlir.undef
2737+
// CHECK: vector.insert
2738+
func.func @vector_insert_const_regression(%arg0: i8) -> vector<4xi8> {
2739+
%0 = llvm.mlir.undef : vector<4xi8>
2740+
%1 = vector.insert %arg0, %0 [0] : i8 into vector<4xi8>
2741+
return %1 : vector<4xi8>
2742+
}

0 commit comments

Comments
 (0)