Skip to content

Commit 6a4750d

Browse files
authored
[mlir] Fix crash when folding tensor.dim(tensor.collapse()) on out-of-bound dim (#119941)
Addresses one of the cases described in #119866
1 parent 65e0031 commit 6a4750d

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2012,7 +2012,8 @@ struct FoldDimOfCollapseShape : public OpRewritePattern<DimOp> {
20122012

20132013
// Only constant dimension values are supported.
20142014
std::optional<int64_t> dim = dimOp.getConstantIndex();
2015-
if (!dim.has_value())
2015+
if (!dim.has_value() ||
2016+
dim.value() >= collapseShapeOp.getResultType().getRank())
20162017
return failure();
20172018

20182019
// Skip static dims. These are folded to constant ops.

mlir/test/Dialect/Tensor/canonicalize.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2344,6 +2344,20 @@ func.func @dim_of_collapse_shape(%t: tensor<?x?x?x7x?xf32>) -> index {
23442344

23452345
// -----
23462346

2347+
// Can't fold when dim is out of bound.
2348+
// CHECK-LABEL: func @out_of_bound_dim_of_collapse_shape(
2349+
// CHECK: %[[DIM:.*]] = tensor.dim
2350+
// CHECK: return %[[DIM]]
2351+
func.func @out_of_bound_dim_of_collapse_shape(%t: tensor<?x?x?x7x?xf32>) -> index {
2352+
%c5 = arith.constant 5 : index
2353+
%0 = tensor.collapse_shape %t [[0], [1, 2, 3, 4]]
2354+
: tensor<?x?x?x7x?xf32> into tensor<?x?xf32>
2355+
%1 = tensor.dim %0, %c5 : tensor<?x?xf32>
2356+
return %1 : index
2357+
}
2358+
2359+
// -----
2360+
23472361
// CHECK-LABEL: func @collapse_expand_fold_to_cast(
23482362
// CHECK-SAME: %[[t:.*]]: tensor<?xf32>
23492363
// CHECK: return %[[t]]

0 commit comments

Comments
 (0)