Skip to content

Commit 7e133eb

Browse files
committed
[mlir][bufferize] Add filterFn option to BufferResultsToOutParams
This allows users to restrict the transformation to a subset of the functions in a module. For example, a user might want to apply the transformation to a module's entry point, but not to the calls in the module because those calls might refer to external C functions outside of their control. Reviewed By: springerm, nicolasvasilache Differential Revision: https://reviews.llvm.org/D137264
1 parent e2dd633 commit 7e133eb

File tree

2 files changed

+46
-10
lines changed

2 files changed

+46
-10
lines changed

mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,25 @@ std::unique_ptr<Pass> createBufferHoistingPass();
3535
/// reallocations inside of loops.
3636
std::unique_ptr<Pass> createBufferLoopHoistingPass();
3737

38+
// Options struct for BufferResultsToOutParams pass.
39+
// Note: defined only here, not in tablegen.
40+
struct BufferResultsToOutParamsOptions {
41+
// Filter function; returns true if the function should be converted.
42+
// Defaults to true, i.e. all functions are converted.
43+
llvm::function_ref<bool(func::FuncOp *)> filterFn = [](func::FuncOp *func) {
44+
return true;
45+
};
46+
};
47+
3848
/// Creates a pass that converts memref function results to out-params.
39-
std::unique_ptr<Pass> createBufferResultsToOutParamsPass();
49+
std::unique_ptr<Pass> createBufferResultsToOutParamsPass(
50+
const BufferResultsToOutParamsOptions &options = {});
4051

4152
/// Replace buffers that are returned from a function with an out parameter.
4253
/// Also update all call sites.
43-
LogicalResult promoteBufferResultsToOutParams(ModuleOp module);
54+
LogicalResult
55+
promoteBufferResultsToOutParams(ModuleOp module,
56+
const BufferResultsToOutParamsOptions &options);
4457

4558
/// Creates a pass that drops memref function results that are equivalent to a
4659
/// function argument.

mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -119,9 +119,21 @@ static void updateReturnOps(func::FuncOp func,
119119

120120
// Updates all CallOps in the scope of the given ModuleOp by allocating
121121
// temporary buffers for newly introduced out params.
122-
static LogicalResult updateCalls(ModuleOp module) {
122+
static LogicalResult
123+
updateCalls(ModuleOp module,
124+
const bufferization::BufferResultsToOutParamsOptions &options) {
123125
bool didFail = false;
126+
SymbolTable symtab(module);
124127
module.walk([&](func::CallOp op) {
128+
auto callee = symtab.lookup<func::FuncOp>(op.getCallee());
129+
if (!callee) {
130+
op.emitError() << "cannot find callee '" << op.getCallee() << "' in "
131+
<< "symbol table";
132+
didFail = true;
133+
return;
134+
}
135+
if (!options.filterFn(&callee))
136+
return;
125137
SmallVector<Value, 6> replaceWithNewCallResults;
126138
SmallVector<Value, 6> replaceWithOutParams;
127139
for (OpResult result : op.getResults()) {
@@ -169,17 +181,20 @@ static LogicalResult updateCalls(ModuleOp module) {
169181
return failure(didFail);
170182
}
171183

172-
LogicalResult
173-
mlir::bufferization::promoteBufferResultsToOutParams(ModuleOp module) {
184+
LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
185+
ModuleOp module,
186+
const bufferization::BufferResultsToOutParamsOptions &options) {
174187
for (auto func : module.getOps<func::FuncOp>()) {
188+
if (!options.filterFn(&func))
189+
continue;
175190
SmallVector<BlockArgument, 6> appendedEntryArgs;
176191
if (failed(updateFuncOp(func, appendedEntryArgs)))
177192
return failure();
178193
if (func.isExternal())
179194
continue;
180195
updateReturnOps(func, appendedEntryArgs);
181196
}
182-
if (failed(updateCalls(module)))
197+
if (failed(updateCalls(module, options)))
183198
return failure();
184199
return success();
185200
}
@@ -188,14 +203,22 @@ namespace {
188203
struct BufferResultsToOutParamsPass
189204
: bufferization::impl::BufferResultsToOutParamsBase<
190205
BufferResultsToOutParamsPass> {
206+
explicit BufferResultsToOutParamsPass(
207+
const bufferization::BufferResultsToOutParamsOptions &options)
208+
: options(options) {}
209+
191210
void runOnOperation() override {
192-
if (failed(bufferization::promoteBufferResultsToOutParams(getOperation())))
211+
if (failed(bufferization::promoteBufferResultsToOutParams(getOperation(),
212+
options)))
193213
return signalPassFailure();
194214
}
215+
216+
private:
217+
bufferization::BufferResultsToOutParamsOptions options;
195218
};
196219
} // namespace
197220

198-
std::unique_ptr<Pass>
199-
mlir::bufferization::createBufferResultsToOutParamsPass() {
200-
return std::make_unique<BufferResultsToOutParamsPass>();
221+
std::unique_ptr<Pass> mlir::bufferization::createBufferResultsToOutParamsPass(
222+
const bufferization::BufferResultsToOutParamsOptions &options) {
223+
return std::make_unique<BufferResultsToOutParamsPass>(options);
201224
}

0 commit comments

Comments
 (0)