Skip to content

Commit bcd14b0

Browse files
authored
[mlir][bufferization] Fix SimplifyClones with dealloc before cloneOp (#79098)
The SimplifyClones pass relies on the assumption that the deallocOp follows the cloneOp. However, a crash occurs when there is a redundantDealloc preceding the cloneOp. This PR addresses the issue by ensuring the presence of deallocOp after cloneOp. The verification is performed by checking if the loop of the sub sequent node of cloneOp reaches the tail of the list. Fix #74306
1 parent 5cc0f76 commit bcd14b0

File tree

2 files changed

+15
-0
lines changed

2 files changed

+15
-0
lines changed

mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -505,6 +505,9 @@ struct SimplifyClones : public OpRewritePattern<CloneOp> {
505505
// of the source.
506506
for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
507507
pos = pos->getNextNode()) {
508+
// Bail if we run out of operations while looking for a deallocation op.
509+
if (!pos)
510+
return failure();
508511
auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
509512
if (!effectInterface)
510513
continue;

mlir/test/Dialect/Bufferization/canonicalize.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,18 @@ func.func @clone_and_realloc(%arg0: memref<?xf32>) {
235235

236236
// -----
237237

238+
// Verify SimplifyClones skips clones with preceding deallocation.
239+
// CHECK-LABEL: @clone_and_preceding_dealloc
240+
func.func @clone_and_preceding_dealloc(%arg0: memref<?xf32>) -> memref<32xf32> {
241+
memref.dealloc %arg0 : memref<?xf32>
242+
%0 = bufferization.clone %arg0 : memref<?xf32> to memref<32xf32>
243+
return %0 : memref<32xf32>
244+
}
245+
// CHECK-SAME: %[[ARG:.*]]: memref<?xf32>
246+
// CHECK-NOT: %cast = memref.cast %[[ARG]]
247+
248+
// -----
249+
238250
// CHECK-LABEL: func @tensor_cast_to_memref
239251
// CHECK-SAME: %[[ARG0:.+]]: tensor<4x6x16x32xi8>
240252
func.func @tensor_cast_to_memref(%arg0 : tensor<4x6x16x32xi8>) ->

0 commit comments

Comments
 (0)