Skip to content

Commit 0a22a80

Browse files
authored
[mlir][vector] Fix extractelement/insertelement folder crash on poison attr (#71333)
Types of incoming attributes weren't properly checked.
1 parent edea974 commit 0a22a80

File tree

2 files changed

+79
-13
lines changed

2 files changed

+79
-13
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1188,9 +1188,6 @@ OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {
11881188
if (!adaptor.getPosition())
11891189
return {};
11901190

1191-
Attribute src = adaptor.getVector();
1192-
Attribute pos = adaptor.getPosition();
1193-
11941191
// Fold extractelement (splat X) -> X.
11951192
if (auto splat = getVector().getDefiningOp<vector::SplatOp>())
11961193
return splat.getInput();
@@ -1200,13 +1197,16 @@ OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {
12001197
if (!llvm::isa<VectorType>(broadcast.getSource().getType()))
12011198
return broadcast.getSource();
12021199

1200+
auto src = dyn_cast_or_null<DenseElementsAttr>(adaptor.getVector());
1201+
auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
12031202
if (!pos || !src)
12041203
return {};
12051204

1206-
auto srcElements = llvm::cast<DenseElementsAttr>(src).getValues<Attribute>();
1205+
auto srcElements = src.getValues<Attribute>();
12071206

1208-
auto attr = llvm::dyn_cast<IntegerAttr>(pos);
1209-
uint64_t posIdx = attr.getInt();
1207+
uint64_t posIdx = pos.getInt();
1208+
if (posIdx >= srcElements.size())
1209+
return {};
12101210

12111211
return srcElements[posIdx];
12121212
}
@@ -2511,18 +2511,20 @@ OpFoldResult vector::InsertElementOp::fold(FoldAdaptor adaptor) {
25112511
if (!adaptor.getPosition())
25122512
return {};
25132513

2514-
Attribute src = adaptor.getSource();
2515-
Attribute dst = adaptor.getDest();
2516-
Attribute pos = adaptor.getPosition();
2514+
auto src = dyn_cast_or_null<TypedAttr>(adaptor.getSource());
2515+
auto dst = dyn_cast_or_null<DenseElementsAttr>(adaptor.getDest());
2516+
auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
25172517
if (!src || !dst || !pos)
25182518
return {};
25192519

2520-
auto dstElements = llvm::cast<DenseElementsAttr>(dst).getValues<Attribute>();
2520+
if (src.getType() != getDestVectorType().getElementType())
2521+
return {};
2522+
2523+
auto dstElements = dst.getValues<Attribute>();
25212524

25222525
SmallVector<Attribute> results(dstElements);
25232526

2524-
auto attr = llvm::dyn_cast<IntegerAttr>(pos);
2525-
uint64_t posIdx = attr.getInt();
2527+
uint64_t posIdx = pos.getInt();
25262528
if (posIdx >= results.size())
25272529
return {};
25282530
results[posIdx] = src;

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2027,6 +2027,46 @@ func.func @insert_element_invalid_fold() -> vector<1xf32> {
20272027
return %46 : vector<1xf32>
20282028
}
20292029

2030+
2031+
// -----
2032+
2033+
// Do not crash on poison
2034+
// CHECK-LABEL: func @insert_poison_fold1
2035+
// CHECK: vector.insertelement
2036+
func.func @insert_poison_fold1() -> vector<4xi32> {
2037+
%v = ub.poison : vector<4xi32>
2038+
%s = arith.constant 7 : i32
2039+
%i = arith.constant 2 : i32
2040+
%1 = vector.insertelement %s, %v[%i : i32] : vector<4xi32>
2041+
return %1 : vector<4xi32>
2042+
}
2043+
2044+
// -----
2045+
2046+
// Do not crash on poison
2047+
// CHECK-LABEL: func @insert_poison_fold2
2048+
// CHECK: vector.insertelement
2049+
func.func @insert_poison_fold2() -> vector<4xi32> {
2050+
%v = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32>
2051+
%s = ub.poison : i32
2052+
%i = arith.constant 2 : i32
2053+
%1 = vector.insertelement %s, %v[%i : i32] : vector<4xi32>
2054+
return %1 : vector<4xi32>
2055+
}
2056+
2057+
// -----
2058+
2059+
// Do not crash on poison
2060+
// CHECK-LABEL: func @insert_poison_fold3
2061+
// CHECK: vector.insertelement
2062+
func.func @insert_poison_fold3() -> vector<4xi32> {
2063+
%v = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32>
2064+
%s = arith.constant 7 : i32
2065+
%i = ub.poison : i32
2066+
%1 = vector.insertelement %s, %v[%i : i32] : vector<4xi32>
2067+
return %1 : vector<4xi32>
2068+
}
2069+
20302070
// -----
20312071

20322072
// CHECK-LABEL: func @extract_element_fold
@@ -2051,6 +2091,30 @@ func.func @extract_element_splat_fold(%a : i32) -> i32 {
20512091

20522092
// -----
20532093

2094+
// Do not crash on poison
2095+
// CHECK-LABEL: func @extract_element_poison_fold1
2096+
// CHECK: vector.extractelement
2097+
func.func @extract_element_poison_fold1() -> i32 {
2098+
%v = ub.poison : vector<4xi32>
2099+
%i = arith.constant 2 : i32
2100+
%1 = vector.extractelement %v[%i : i32] : vector<4xi32>
2101+
return %1 : i32
2102+
}
2103+
2104+
// -----
2105+
2106+
// Do not crash on poison
2107+
// CHECK-LABEL: func @extract_element_poison_fold2
2108+
// CHECK: vector.extractelement
2109+
func.func @extract_element_poison_fold2() -> i32 {
2110+
%v = arith.constant dense<[1, 3, 5, 7]> : vector<4xi32>
2111+
%i = ub.poison : i32
2112+
%1 = vector.extractelement %v[%i : i32] : vector<4xi32>
2113+
return %1 : i32
2114+
}
2115+
2116+
// -----
2117+
20542118
// CHECK-LABEL: func @reduce_one_element_vector_extract
20552119
// CHECK-SAME: (%[[V:.+]]: vector<1xf32>)
20562120
// CHECK: %[[S:.+]] = vector.extract %[[V]][0] : f32 from vector<1xf32>
@@ -2436,4 +2500,4 @@ func.func @load_store_forwarding_rank_mismatch(%v0: vector<4x1x1xf32>, %arg0: te
24362500
permutation_map = affine_map<(d0, d1, d2) -> (d1, 0, d2, 0)>} :
24372501
tensor<4x4x4xf32>, vector<1x100x4x5xf32>
24382502
return %r : vector<1x100x4x5xf32>
2439-
}
2503+
}

0 commit comments

Comments
 (0)