Skip to content

Commit eaa4b6c

Browse files
authored
[mlir][bufferization] Clone simplify fails when input and result type not cast compatiable (#71310)
The simplify of bufferization.clone generates a memref.cast op, but the checks in simplify do not verify whether the operand types and return types of clone op is compatiable, leading to errors. This patch addresses this issue.
1 parent 6fdc2ce commit eaa4b6c

File tree

2 files changed

+17
-0
lines changed

2 files changed

+17
-0
lines changed

mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,11 @@ struct SimplifyClones : public OpRewritePattern<CloneOp> {
457457
}
458458

459459
Value source = cloneOp.getInput();
460+
if (source.getType() != cloneOp.getType() &&
461+
!memref::CastOp::areCastCompatible({source.getType()},
462+
{cloneOp.getType()}))
463+
return failure();
464+
460465
// Aims to find the dealloc op for the canonical source
461466
// which otherwise could prevent removal of unnecessary allocs.
462467
Value canonicalSource = source;

mlir/test/Dialect/Bufferization/canonicalize.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,18 @@ func.func @clone_and_cast(%arg0: memref<?xf32>) -> memref<32xf32> {
156156

157157
// -----
158158

159+
// CHECK-LABEL: @clone_incompatible
160+
func.func @clone_incompatible(%arg0: memref<32xf32, strided<[2]>>) -> memref<32xf32> {
161+
%0 = bufferization.clone %arg0 : memref<32xf32, strided<[2]>> to memref<32xf32>
162+
memref.dealloc %arg0 : memref<32xf32, strided<[2]>>
163+
return %0 : memref<32xf32>
164+
}
165+
// CHECK-SAME: %[[ARG:.*]]: memref<32xf32, strided<[2]>>
166+
// CHECK-NEXT: bufferization.clone %[[ARG]] : memref<32xf32, strided<[2]>> to memref<32xf32>
167+
// CHECK-NOT: memref.cast
168+
169+
// -----
170+
159171
// CHECK-LABEL: @alias_is_freed
160172
func.func @alias_is_freed(%arg0 : memref<?xf32>) {
161173
%0 = memref.cast %arg0 : memref<?xf32> to memref<32xf32>

0 commit comments

Comments
 (0)