@@ -21,6 +21,7 @@ namespace bufferization {
21
21
} // namespace mlir
22
22
23
23
using namespace mlir ;
24
+ using MemCpyFn = bufferization::BufferResultsToOutParamsOptions::MemCpyFn;
24
25
25
26
// / Return `true` if the given MemRef type has a fully dynamic layout.
26
27
static bool hasFullyDynamicLayoutMap (MemRefType type) {
@@ -97,9 +98,10 @@ updateFuncOp(func::FuncOp func,
97
98
// Updates all ReturnOps in the scope of the given func::FuncOp by either
98
99
// keeping them as return values or copying the associated buffer contents into
99
100
// the given out-params.
100
- static void updateReturnOps (func::FuncOp func,
101
- ArrayRef<BlockArgument> appendedEntryArgs) {
102
- func.walk ([&](func::ReturnOp op) {
101
+ static LogicalResult updateReturnOps (func::FuncOp func,
102
+ ArrayRef<BlockArgument> appendedEntryArgs,
103
+ MemCpyFn memCpyFn) {
104
+ auto res = func.walk ([&](func::ReturnOp op) {
103
105
SmallVector<Value, 6 > copyIntoOutParams;
104
106
SmallVector<Value, 6 > keepAsReturnOperands;
105
107
for (Value operand : op.getOperands ()) {
@@ -109,12 +111,16 @@ static void updateReturnOps(func::FuncOp func,
109
111
keepAsReturnOperands.push_back (operand);
110
112
}
111
113
OpBuilder builder (op);
112
- for (auto t : llvm::zip (copyIntoOutParams, appendedEntryArgs))
113
- builder.create <memref::CopyOp>(op.getLoc (), std::get<0 >(t),
114
- std::get<1 >(t));
114
+ for (auto t : llvm::zip (copyIntoOutParams, appendedEntryArgs)) {
115
+ if (failed (
116
+ memCpyFn (builder, op.getLoc (), std::get<0 >(t), std::get<1 >(t))))
117
+ return WalkResult::interrupt ();
118
+ }
115
119
builder.create <func::ReturnOp>(op.getLoc (), keepAsReturnOperands);
116
120
op.erase ();
121
+ return WalkResult::advance ();
117
122
});
123
+ return failure (res.wasInterrupted ());
118
124
}
119
125
120
126
// Updates all CallOps in the scope of the given ModuleOp by allocating
@@ -192,7 +198,15 @@ LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
192
198
return failure ();
193
199
if (func.isExternal ())
194
200
continue ;
195
- updateReturnOps (func, appendedEntryArgs);
201
+ auto defaultMemCpyFn = [](OpBuilder &builder, Location loc, Value from,
202
+ Value to) {
203
+ builder.create <memref::CopyOp>(loc, from, to);
204
+ return success ();
205
+ };
206
+ if (failed (updateReturnOps (func, appendedEntryArgs,
207
+ options.memCpyFn .value_or (defaultMemCpyFn)))) {
208
+ return failure ();
209
+ }
196
210
}
197
211
if (failed (updateCalls (module , options)))
198
212
return failure ();
0 commit comments