Skip to content

Commit 26645ae

Browse files
author
Menooker
authored
[mlir][memref] Fix hoist-static-allocs option of buffer-results-to-out-params when function parameters are returned (#102093)
buffer-results-to-out-params pass will have a nullptr-referencing error when hoist-static-allocs option is on, when the return value of a function is a parameter of the function. This PR fixes this issue.
1 parent d65ff3e commit 26645ae

File tree

2 files changed

+17
-2
lines changed

2 files changed

+17
-2
lines changed

mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,8 @@ static LogicalResult updateReturnOps(func::FuncOp func,
120120
}
121121
OpBuilder builder(op);
122122
for (auto [orig, arg] : llvm::zip(copyIntoOutParams, appendedEntryArgs)) {
123-
if (hoistStaticAllocs && isa<memref::AllocOp>(orig.getDefiningOp()) &&
123+
if (hoistStaticAllocs &&
124+
isa_and_nonnull<memref::AllocOp>(orig.getDefiningOp()) &&
124125
mlir::cast<MemRefType>(orig.getType()).hasStaticShape()) {
125126
orig.replaceAllUsesWith(arg);
126127
orig.getDefiningOp()->erase();

mlir/test/Transforms/buffer-results-to-out-params-elim.mlir

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,4 +34,18 @@ func.func @basic_dynamic(%d: index) -> (memref<?xf32>) {
3434
%b = memref.alloc(%d) : memref<?xf32>
3535
"test.source"(%b) : (memref<?xf32>) -> ()
3636
return %b : memref<?xf32>
37-
}
37+
}
38+
39+
// -----
40+
41+
// no change due to writing to func args
42+
// CHECK-LABEL: func @return_arg(
43+
// CHECK-SAME: %[[ARG0:.*]]: memref<128x256xf32>, %[[ARG1:.*]]: memref<128x256xf32>, %[[ARG2:.*]]: memref<128x256xf32>) {
44+
// CHECK: "test.source"(%[[ARG0]], %[[ARG1]])
45+
// CHECK: memref.copy
46+
// CHECK: return
47+
// CHECK: }
48+
func.func @return_arg(%arg0: memref<128x256xf32>, %arg1: memref<128x256xf32>) -> memref<128x256xf32> {
49+
"test.source"(%arg0, %arg1) : (memref<128x256xf32>, memref<128x256xf32>) -> ()
50+
return %arg0 : memref<128x256xf32>
51+
}

0 commit comments

Comments
 (0)