@@ -22,6 +22,7 @@ namespace bufferization {
22
22
} // namespace mlir
23
23
24
24
using namespace mlir ;
25
+ using AllocationFn = bufferization::BufferResultsToOutParamsOpts::AllocationFn;
25
26
using MemCpyFn = bufferization::BufferResultsToOutParamsOpts::MemCpyFn;
26
27
27
28
// / Return `true` if the given MemRef type has a fully dynamic layout.
@@ -141,9 +142,8 @@ static LogicalResult updateReturnOps(func::FuncOp func,
141
142
142
143
// Updates all CallOps in the scope of the given ModuleOp by allocating
143
144
// temporary buffers for newly introduced out params.
144
- static LogicalResult
145
- updateCalls (ModuleOp module ,
146
- const bufferization::BufferResultsToOutParamsOpts &options) {
145
+ static LogicalResult updateCalls (ModuleOp module , AllocationFn allocationFn,
146
+ std::function<bool (func::FuncOp *)> filterFn) {
147
147
bool didFail = false ;
148
148
SymbolTable symtab (module );
149
149
module .walk ([&](func::CallOp op) {
@@ -154,7 +154,7 @@ updateCalls(ModuleOp module,
154
154
didFail = true ;
155
155
return ;
156
156
}
157
- if (!options. filterFn (&callee))
157
+ if (!filterFn (&callee))
158
158
return ;
159
159
SmallVector<Value, 6 > replaceWithNewCallResults;
160
160
SmallVector<Value, 6 > replaceWithOutParams;
@@ -177,7 +177,13 @@ updateCalls(ModuleOp module,
177
177
auto allocType =
178
178
MemRefType::get (memrefType.getShape (), memrefType.getElementType (),
179
179
AffineMap (), memrefType.getMemorySpace ());
180
- Value outParam = builder.create <memref::AllocOp>(op.getLoc (), allocType);
180
+ auto maybeOutParam = allocationFn (builder, op.getLoc (), allocType);
181
+ if (failed (maybeOutParam)) {
182
+ op.emitError () << " failed to create allocation op" ;
183
+ didFail = true ;
184
+ return ;
185
+ }
186
+ Value outParam = maybeOutParam.value ();
181
187
if (!hasStaticIdentityLayout (memrefType)) {
182
188
// Layout maps are already checked in `updateFuncOp`.
183
189
assert (hasFullyDynamicLayoutMap (memrefType) &&
@@ -226,7 +232,13 @@ LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
226
232
return failure ();
227
233
}
228
234
}
229
- if (failed (updateCalls (module , options)))
235
+ auto defaultAllocationFn = [](OpBuilder &builder, Location loc,
236
+ MemRefType type) {
237
+ return builder.create <memref::AllocOp>(loc, type).getResult ();
238
+ };
239
+ if (failed (updateCalls (module ,
240
+ options.allocationFn .value_or (defaultAllocationFn),
241
+ options.filterFn )))
230
242
return failure ();
231
243
return success ();
232
244
}
0 commit comments