@@ -107,10 +107,9 @@ updateFuncOp(func::FuncOp func,
107
107
// Updates all ReturnOps in the scope of the given func::FuncOp by either
108
108
// keeping them as return values or copying the associated buffer contents into
109
109
// the given out-params.
110
- static LogicalResult updateReturnOps (func::FuncOp func,
111
- ArrayRef<BlockArgument> appendedEntryArgs,
112
- MemCpyFn memCpyFn,
113
- bool hoistStaticAllocs) {
110
+ static LogicalResult
111
+ updateReturnOps (func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs,
112
+ const bufferization::BufferResultsToOutParamsOpts &options) {
114
113
auto res = func.walk ([&](func::ReturnOp op) {
115
114
SmallVector<Value, 6 > copyIntoOutParams;
116
115
SmallVector<Value, 6 > keepAsReturnOperands;
@@ -122,14 +121,14 @@ static LogicalResult updateReturnOps(func::FuncOp func,
122
121
}
123
122
OpBuilder builder (op);
124
123
for (auto [orig, arg] : llvm::zip (copyIntoOutParams, appendedEntryArgs)) {
125
- if (hoistStaticAllocs &&
124
+ if (options. hoistStaticAllocs &&
126
125
isa_and_nonnull<bufferization::AllocationOpInterface>(
127
126
orig.getDefiningOp ()) &&
128
127
mlir::cast<MemRefType>(orig.getType ()).hasStaticShape ()) {
129
128
orig.replaceAllUsesWith (arg);
130
129
orig.getDefiningOp ()->erase ();
131
130
} else {
132
- if (failed (memCpyFn (builder, op.getLoc (), orig, arg)))
131
+ if (failed (options. memCpyFn (builder, op.getLoc (), orig, arg)))
133
132
return WalkResult::interrupt ();
134
133
}
135
134
}
@@ -142,8 +141,9 @@ static LogicalResult updateReturnOps(func::FuncOp func,
142
141
143
142
// Updates all CallOps in the scope of the given ModuleOp by allocating
144
143
// temporary buffers for newly introduced out params.
145
- static LogicalResult updateCalls (ModuleOp module , AllocationFn allocationFn,
146
- std::function<bool (func::FuncOp *)> filterFn) {
144
+ static LogicalResult
145
+ updateCalls (ModuleOp module ,
146
+ const bufferization::BufferResultsToOutParamsOpts &options) {
147
147
bool didFail = false ;
148
148
SymbolTable symtab (module );
149
149
module .walk ([&](func::CallOp op) {
@@ -154,7 +154,7 @@ static LogicalResult updateCalls(ModuleOp module, AllocationFn allocationFn,
154
154
didFail = true ;
155
155
return ;
156
156
}
157
- if (!filterFn (&callee))
157
+ if (!options. filterFn (&callee))
158
158
return ;
159
159
SmallVector<Value, 6 > replaceWithNewCallResults;
160
160
SmallVector<Value, 6 > replaceWithOutParams;
@@ -177,7 +177,8 @@ static LogicalResult updateCalls(ModuleOp module, AllocationFn allocationFn,
177
177
auto allocType =
178
178
MemRefType::get (memrefType.getShape (), memrefType.getElementType (),
179
179
AffineMap (), memrefType.getMemorySpace ());
180
- auto maybeOutParam = allocationFn (builder, op.getLoc (), allocType);
180
+ auto maybeOutParam =
181
+ options.allocationFn (builder, op.getLoc (), allocType);
181
182
if (failed (maybeOutParam)) {
182
183
op.emitError () << " failed to create allocation op" ;
183
184
didFail = true ;
@@ -221,24 +222,11 @@ LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
221
222
return failure ();
222
223
if (func.isExternal ())
223
224
continue ;
224
- auto defaultMemCpyFn = [](OpBuilder &builder, Location loc, Value from,
225
- Value to) {
226
- builder.create <memref::CopyOp>(loc, from, to);
227
- return success ();
228
- };
229
- if (failed (updateReturnOps (func, appendedEntryArgs,
230
- options.memCpyFn .value_or (defaultMemCpyFn),
231
- options.hoistStaticAllocs ))) {
225
+ if (failed (updateReturnOps (func, appendedEntryArgs, options))) {
232
226
return failure ();
233
227
}
234
228
}
235
- auto defaultAllocationFn = [](OpBuilder &builder, Location loc,
236
- MemRefType type) {
237
- return builder.create <memref::AllocOp>(loc, type).getResult ();
238
- };
239
- if (failed (updateCalls (module ,
240
- options.allocationFn .value_or (defaultAllocationFn),
241
- options.filterFn )))
229
+ if (failed (updateCalls (module , options)))
242
230
return failure ();
243
231
return success ();
244
232
}
0 commit comments