Skip to content

Commit dc55d31

Browse files
CoTinkerjinzhi129
andauthored
[mlir][tensor] Fix a crash in ExtractOp::fold (llvm#115001)
This PR fixes a crash when the tensor of `tensor.extract` is a dense resource elements attribute. Fixes llvm#114728. Co-authored-by: jinzhi <[email protected]>
1 parent c96a85a commit dc55d31

File tree

2 files changed

+22
-3
lines changed

2 files changed

+22
-3
lines changed

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1128,12 +1128,17 @@ LogicalResult ExtractOp::verify() {
11281128
}
11291129

11301130
OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
1131-
// If this is a splat elements attribute, simply return the value. All of
1132-
// the elements of a splat attribute are the same.
1133-
if (Attribute tensor = adaptor.getTensor())
1131+
if (Attribute tensor = adaptor.getTensor()) {
1132+
// If this is a splat elements attribute, simply return the value.
1133+
// All of the elements of a splat attribute are the same.
11341134
if (auto splatTensor = llvm::dyn_cast<SplatElementsAttr>(tensor))
11351135
return splatTensor.getSplatValue<Attribute>();
11361136

1137+
// If this is a dense resource elements attribute, return.
1138+
if (isa<DenseResourceElementsAttr>(tensor))
1139+
return {};
1140+
}
1141+
11371142
// Collect the constant indices into the tensor.
11381143
SmallVector<uint64_t, 8> indices;
11391144
for (Attribute indice : adaptor.getIndices()) {

mlir/test/Dialect/Tensor/canonicalize.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,20 @@ func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, complex<f32>) {
173173

174174
// -----
175175

176+
// Ensure extract dense resource elements not crash.
177+
178+
// CHECK-LABEL: func @extract_dense_resource_nofold
179+
func.func @extract_dense_resource_nofold() -> i64 {
180+
// CHECK: %[[EXT:.+]] = tensor.extract
181+
// CHECK-NEXT: return %[[EXT]]
182+
%c0 = arith.constant 0 : index
183+
%cst = arith.constant dense_resource<__elided__> : tensor<1xi64>
184+
%extracted = tensor.extract %cst[%c0] : tensor<1xi64>
185+
return %extracted : i64
186+
}
187+
188+
// -----
189+
176190
// CHECK-LABEL: func @fold_insert
177191
func.func @fold_insert(%arg0 : index) -> (tensor<4xf32>) {
178192
// Fold an insert into a splat.

0 commit comments

Comments
 (0)