Skip to content

Commit e6048b7

Browse files
authored
[MLIR][Bufferization] BufferResultsToOutParams: Add option to add attribute to output arguments (#84320)
Adds a new pass option `add-result-attr` that will make the pass add the attribute `{bufferize.result}` to each argument that was converted from a result. This is important e.g. when later using the python bindings / execution engine to understand which arguments are actually results. To be able to test this, the pass option was added to the tablegen. To avoid collisions with the existing, manually defined option struct `BufferResultsToOutParamsOptions`, that one was renamed to `BufferResultsToOutParamsOpts`.
1 parent 071f72a commit e6048b7

File tree

4 files changed

+49
-11
lines changed

4 files changed

+49
-11
lines changed

mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ std::unique_ptr<Pass> createBufferLoopHoistingPass();
148148

149149
// Options struct for BufferResultsToOutParams pass.
150150
// Note: defined only here, not in tablegen.
151-
struct BufferResultsToOutParamsOptions {
151+
struct BufferResultsToOutParamsOpts {
152152
/// Memcpy function: Generate a memcpy between two memrefs.
153153
using MemCpyFn =
154154
std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
@@ -162,17 +162,21 @@ struct BufferResultsToOutParamsOptions {
162162
/// Memcpy function; used to create a copy between two memrefs.
163163
/// If this is empty, memref.copy is used.
164164
std::optional<MemCpyFn> memCpyFn;
165+
166+
/// If true, the pass adds a "bufferize.result" attribute to each output
167+
/// parameter.
168+
bool addResultAttribute = false;
165169
};
166170

167171
/// Creates a pass that converts memref function results to out-params.
168172
std::unique_ptr<Pass> createBufferResultsToOutParamsPass(
169-
const BufferResultsToOutParamsOptions &options = {});
173+
const BufferResultsToOutParamsOpts &options = {});
170174

171175
/// Replace buffers that are returned from a function with an out parameter.
172176
/// Also update all call sites.
173177
LogicalResult
174178
promoteBufferResultsToOutParams(ModuleOp module,
175-
const BufferResultsToOutParamsOptions &options);
179+
const BufferResultsToOutParamsOpts &options);
176180

177181
/// Creates a pass that drops memref function results that are equivalent to a
178182
/// function argument.

mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,11 @@ def BufferResultsToOutParams : Pass<"buffer-results-to-out-params", "ModuleOp">
316316
buffers for results need to be allocated in the caller. This currently only
317317
works for static shaped memrefs.
318318
}];
319+
let options = [
320+
Option<"addResultAttribute", "add-result-attr", "bool",
321+
/*default=*/"false",
322+
"Add the attribute 'bufferize.result' to all output parameters.">,
323+
];
319324
let constructor = "mlir::bufferization::createBufferResultsToOutParamsPass()";
320325
let dependentDialects = ["memref::MemRefDialect"];
321326
}

mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ namespace bufferization {
2121
} // namespace mlir
2222

2323
using namespace mlir;
24-
using MemCpyFn = bufferization::BufferResultsToOutParamsOptions::MemCpyFn;
24+
using MemCpyFn = bufferization::BufferResultsToOutParamsOpts::MemCpyFn;
2525

2626
/// Return `true` if the given MemRef type has a fully dynamic layout.
2727
static bool hasFullyDynamicLayoutMap(MemRefType type) {
@@ -45,9 +45,12 @@ static bool hasStaticIdentityLayout(MemRefType type) {
4545
// Updates the func op and entry block.
4646
//
4747
// Any args appended to the entry block are added to `appendedEntryArgs`.
48+
// If `addResultAttribute` is true, adds the unit attribute `bufferize.result`
49+
// to each newly created function argument.
4850
static LogicalResult
4951
updateFuncOp(func::FuncOp func,
50-
SmallVectorImpl<BlockArgument> &appendedEntryArgs) {
52+
SmallVectorImpl<BlockArgument> &appendedEntryArgs,
53+
bool addResultAttribute) {
5154
auto functionType = func.getFunctionType();
5255

5356
// Collect information about the results will become appended arguments.
@@ -80,6 +83,10 @@ updateFuncOp(func::FuncOp func,
8083
for (int i = 0, e = erasedResultTypes.size(); i < e; ++i, ++erasedIndicesIt) {
8184
func.setArgAttrs(functionType.getNumInputs() + i,
8285
func.getResultAttrs(*erasedIndicesIt));
86+
if (addResultAttribute)
87+
func.setArgAttr(functionType.getNumInputs() + i,
88+
StringAttr::get(func.getContext(), "bufferize.result"),
89+
UnitAttr::get(func.getContext()));
8390
}
8491

8592
// Erase the results.
@@ -127,7 +134,7 @@ static LogicalResult updateReturnOps(func::FuncOp func,
127134
// temporary buffers for newly introduced out params.
128135
static LogicalResult
129136
updateCalls(ModuleOp module,
130-
const bufferization::BufferResultsToOutParamsOptions &options) {
137+
const bufferization::BufferResultsToOutParamsOpts &options) {
131138
bool didFail = false;
132139
SymbolTable symtab(module);
133140
module.walk([&](func::CallOp op) {
@@ -189,12 +196,13 @@ updateCalls(ModuleOp module,
189196

190197
LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
191198
ModuleOp module,
192-
const bufferization::BufferResultsToOutParamsOptions &options) {
199+
const bufferization::BufferResultsToOutParamsOpts &options) {
193200
for (auto func : module.getOps<func::FuncOp>()) {
194201
if (!options.filterFn(&func))
195202
continue;
196203
SmallVector<BlockArgument, 6> appendedEntryArgs;
197-
if (failed(updateFuncOp(func, appendedEntryArgs)))
204+
if (failed(
205+
updateFuncOp(func, appendedEntryArgs, options.addResultAttribute)))
198206
return failure();
199207
if (func.isExternal())
200208
continue;
@@ -218,21 +226,25 @@ struct BufferResultsToOutParamsPass
218226
: bufferization::impl::BufferResultsToOutParamsBase<
219227
BufferResultsToOutParamsPass> {
220228
explicit BufferResultsToOutParamsPass(
221-
const bufferization::BufferResultsToOutParamsOptions &options)
229+
const bufferization::BufferResultsToOutParamsOpts &options)
222230
: options(options) {}
223231

224232
void runOnOperation() override {
233+
// Convert from pass options in tablegen to BufferResultsToOutParamsOpts.
234+
if (addResultAttribute)
235+
options.addResultAttribute = true;
236+
225237
if (failed(bufferization::promoteBufferResultsToOutParams(getOperation(),
226238
options)))
227239
return signalPassFailure();
228240
}
229241

230242
private:
231-
bufferization::BufferResultsToOutParamsOptions options;
243+
bufferization::BufferResultsToOutParamsOpts options;
232244
};
233245
} // namespace
234246

235247
std::unique_ptr<Pass> mlir::bufferization::createBufferResultsToOutParamsPass(
236-
const bufferization::BufferResultsToOutParamsOptions &options) {
248+
const bufferization::BufferResultsToOutParamsOpts &options) {
237249
return std::make_unique<BufferResultsToOutParamsPass>(options);
238250
}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
// RUN: mlir-opt -p 'builtin.module(buffer-results-to-out-params{add-result-attr})' -split-input-file -verify-diagnostics %s | FileCheck %s
2+
3+
// CHECK-LABEL: @basic({{.*}}: memref<f32> {bufferize.result})
4+
func.func @basic() -> (memref<f32>) {
5+
%0 = "test.source"() : () -> (memref<f32>)
6+
return %0 : memref<f32>
7+
}
8+
9+
// -----
10+
11+
// CHECK-LABEL: multiple_results
12+
// CHECK-SAME: memref<1xf32> {bufferize.result}
13+
// CHECK-SAME: memref<2xf32> {bufferize.result}
14+
func.func @multiple_results() -> (memref<1xf32>, memref<2xf32>) {
15+
%0, %1 = "test.source"() : () -> (memref<1xf32>, memref<2xf32>)
16+
return %0, %1 : memref<1xf32>, memref<2xf32>
17+
}

0 commit comments

Comments
 (0)