Skip to content

Commit 5a71f7a

Browse files
authored
[mlir] Fix bufferization.alloc_tensor canonicalization crash (#70891)
This make sure that an invalid negative dimension is ignored and stays dynamic instead of crashing the compiler. Fixes #70887
1 parent 30416f3 commit 5a71f7a

File tree

2 files changed

+18
-1
lines changed

2 files changed

+18
-1
lines changed

mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,11 @@ struct ReplaceStaticShapeDims : OpRewritePattern<AllocTensorOp> {
314314
Value value = op.getDynamicSizes()[dynValCounter++];
315315
APInt intVal;
316316
if (matchPattern(value, m_ConstantInt(&intVal))) {
317-
newShape[i] = intVal.getSExtValue();
317+
int64_t dim = intVal.getSExtValue();
318+
if (dim >= 0)
319+
newShape[i] = intVal.getSExtValue();
320+
else
321+
newDynamicSizes.push_back(value);
318322
} else {
319323
newDynamicSizes.push_back(value);
320324
}

mlir/test/Dialect/Bufferization/canonicalize.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,3 +351,16 @@ func.func @dealloc_base_memref_extract_of_alloc(%arg0: memref<2xi32>) {
351351
// CHECK-SAME:([[ARG0:%.+]]: memref<2xi32>)
352352
// CHECK-NOT: memref.alloc(
353353
// CHECK: bufferization.dealloc ([[ARG0]] : memref<2xi32>) if (%true
354+
355+
// -----
356+
357+
// CHECK-LABEL: func @negative_input
358+
func.func @negative_input() -> tensor<?x?x?xf16> {
359+
%idx27 = index.constant 27
360+
%idx-3 = index.constant -3 // negative integer?
361+
%c10 = arith.constant 10 : index
362+
// CHECK: bufferization.alloc_tensor
363+
// CHECK-SAME: tensor<10x?x27xf16>
364+
%11 = bufferization.alloc_tensor(%c10, %idx-3, %idx27) : tensor<?x?x?xf16>
365+
return %11 : tensor<?x?x?xf16>
366+
}

0 commit comments

Comments
 (0)