Skip to content

Commit 4bb9f91

Browse files
authored
[mlir][tensor] fix out-of-bound index in tensor.dim (#85901)
fix a crash when fold tensor.dim with out-of-bound index. Fixes: #70183
1 parent 46a737c commit 4bb9f91

File tree

2 files changed

+26
-0
lines changed

2 files changed

+26
-0
lines changed

mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,9 @@ struct FoldDimOfAllocTensorOp : public OpRewritePattern<tensor::DimOp> {
333333
auto allocTensorOp = dimOp.getSource().getDefiningOp<AllocTensorOp>();
334334
if (!allocTensorOp || !maybeConstantIndex)
335335
return failure();
336+
if (*maybeConstantIndex < 0 ||
337+
*maybeConstantIndex >= allocTensorOp.getType().getRank())
338+
return failure();
336339
if (!allocTensorOp.getType().isDynamicDim(*maybeConstantIndex))
337340
return failure();
338341
rewriter.replaceOp(

mlir/test/Dialect/Tensor/canonicalize.mlir

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2367,3 +2367,26 @@ func.func @dim_of_reshape_undominated(%arg0: tensor<*xf32>, %arg1: tensor<?xinde
23672367
%dim = tensor.dim %reshape, %0 : tensor<*xf32>
23682368
return %dim : index
23692369
}
2370+
2371+
// -----
2372+
2373+
// Test case: This test fails to fold because the index of tensor.dim is out_of_bounds
2374+
// CHECK-LABEL: func @dim_out_of_bounds(
2375+
// CHECK: %[[IDX:.*]] = index.constant 28
2376+
// CHECK-NEXT: bufferization.alloc_tensor
2377+
// CHECK-NEXT: %[[DIM:.*]] = tensor.dim %{{.*}}, %[[IDX]]
2378+
// CHECK-NEXT: memref.alloc
2379+
// CHECK-NEXT: memref.cast
2380+
// CHECK-NEXT: affine.vector_load %{{.*}}[{{.*}}, {{.*}}, symbol(%[[DIM]])]
2381+
// CHECK-NEXT: return
2382+
func.func @dim_out_of_bounds() -> vector<7xi32> {
2383+
%c1 = arith.constant 1 : index
2384+
%idx28 = index.constant 28
2385+
%c29 = arith.constant 29 : index
2386+
%3 = bufferization.alloc_tensor(%c29) : tensor<?xi16>
2387+
%dim = tensor.dim %3, %idx28 : tensor<?xi16>
2388+
%alloc_21 = memref.alloc(%c29) : memref<?x26x2xi32>
2389+
%16 = affine.vector_load %alloc_21[%c1, %c1, %dim] : memref<?x26x2xi32>, vector<7xi32>
2390+
return %16 : vector<7xi32>
2391+
}
2392+

0 commit comments

Comments
 (0)