Skip to content

Commit be1b0a1

Browse files
committed
[MLIR][Bufferization] BufferResultsToOutParams: Add option to add attribute to output arguments
Adds a new pass option `add-result-attr` that will make the pass add the attribute `{bufferize.result}` to each argument that was converted from a result. To be able to test this, the pass option was added to the tablegen. To avoid collisions with the existing, manually defined option struct `BufferResultsToOutParamsOptions`, that one was renamed to `BufferResultsToOutParamsOpts`.
1 parent 071f72a commit be1b0a1

File tree

4 files changed

+49
-11
lines changed

4 files changed

+49
-11
lines changed

mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ std::unique_ptr<Pass> createBufferLoopHoistingPass();
148148

149149
// Options struct for BufferResultsToOutParams pass.
150150
// Note: defined only here, not in tablegen.
151-
struct BufferResultsToOutParamsOptions {
151+
struct BufferResultsToOutParamsOpts {
152152
/// Memcpy function: Generate a memcpy between two memrefs.
153153
using MemCpyFn =
154154
std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
@@ -162,17 +162,21 @@ struct BufferResultsToOutParamsOptions {
162162
/// Memcpy function; used to create a copy between two memrefs.
163163
/// If this is empty, memref.copy is used.
164164
std::optional<MemCpyFn> memCpyFn;
165+
166+
/// If true, the pass adds a "bufferize.result" attribute to each output
167+
/// parameter.
168+
bool addResultAttribute = false;
165169
};
166170

167171
/// Creates a pass that converts memref function results to out-params.
168172
std::unique_ptr<Pass> createBufferResultsToOutParamsPass(
169-
const BufferResultsToOutParamsOptions &options = {});
173+
const BufferResultsToOutParamsOpts &options = {});
170174

171175
/// Replace buffers that are returned from a function with an out parameter.
172176
/// Also update all call sites.
173177
LogicalResult
174178
promoteBufferResultsToOutParams(ModuleOp module,
175-
const BufferResultsToOutParamsOptions &options);
179+
const BufferResultsToOutParamsOpts &options);
176180

177181
/// Creates a pass that drops memref function results that are equivalent to a
178182
/// function argument.

mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,11 @@ def BufferResultsToOutParams : Pass<"buffer-results-to-out-params", "ModuleOp">
316316
buffers for results need to be allocated in the caller. This currently only
317317
works for static shaped memrefs.
318318
}];
319+
let options = [
320+
Option<"addResultAttribute", "add-result-attr", "bool",
321+
/*default=*/"false",
322+
"Add the attribute 'bufferize.result' to all output parameters.">,
323+
];
319324
let constructor = "mlir::bufferization::createBufferResultsToOutParamsPass()";
320325
let dependentDialects = ["memref::MemRefDialect"];
321326
}

mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ namespace bufferization {
2121
} // namespace mlir
2222

2323
using namespace mlir;
24-
using MemCpyFn = bufferization::BufferResultsToOutParamsOptions::MemCpyFn;
24+
using MemCpyFn = bufferization::BufferResultsToOutParamsOpts::MemCpyFn;
2525

2626
/// Return `true` if the given MemRef type has a fully dynamic layout.
2727
static bool hasFullyDynamicLayoutMap(MemRefType type) {
@@ -45,9 +45,12 @@ static bool hasStaticIdentityLayout(MemRefType type) {
4545
// Updates the func op and entry block.
4646
//
4747
// Any args appended to the entry block are added to `appendedEntryArgs`.
48+
// If `addResultAttribute` is true, adds the unit attribute `bufferize.result`
49+
// to each newly created function argument.
4850
static LogicalResult
4951
updateFuncOp(func::FuncOp func,
50-
SmallVectorImpl<BlockArgument> &appendedEntryArgs) {
52+
SmallVectorImpl<BlockArgument> &appendedEntryArgs,
53+
bool addResultAttribute) {
5154
auto functionType = func.getFunctionType();
5255

5356
// Collect information about the results will become appended arguments.
@@ -80,6 +83,10 @@ updateFuncOp(func::FuncOp func,
8083
for (int i = 0, e = erasedResultTypes.size(); i < e; ++i, ++erasedIndicesIt) {
8184
func.setArgAttrs(functionType.getNumInputs() + i,
8285
func.getResultAttrs(*erasedIndicesIt));
86+
if (addResultAttribute)
87+
func.setArgAttr(functionType.getNumInputs() + i,
88+
StringAttr::get(func.getContext(), "bufferize.result"),
89+
UnitAttr::get(func.getContext()));
8390
}
8491

8592
// Erase the results.
@@ -127,7 +134,7 @@ static LogicalResult updateReturnOps(func::FuncOp func,
127134
// temporary buffers for newly introduced out params.
128135
static LogicalResult
129136
updateCalls(ModuleOp module,
130-
const bufferization::BufferResultsToOutParamsOptions &options) {
137+
const bufferization::BufferResultsToOutParamsOpts &options) {
131138
bool didFail = false;
132139
SymbolTable symtab(module);
133140
module.walk([&](func::CallOp op) {
@@ -189,12 +196,13 @@ updateCalls(ModuleOp module,
189196

190197
LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
191198
ModuleOp module,
192-
const bufferization::BufferResultsToOutParamsOptions &options) {
199+
const bufferization::BufferResultsToOutParamsOpts &options) {
193200
for (auto func : module.getOps<func::FuncOp>()) {
194201
if (!options.filterFn(&func))
195202
continue;
196203
SmallVector<BlockArgument, 6> appendedEntryArgs;
197-
if (failed(updateFuncOp(func, appendedEntryArgs)))
204+
if (failed(
205+
updateFuncOp(func, appendedEntryArgs, options.addResultAttribute)))
198206
return failure();
199207
if (func.isExternal())
200208
continue;
@@ -218,21 +226,25 @@ struct BufferResultsToOutParamsPass
218226
: bufferization::impl::BufferResultsToOutParamsBase<
219227
BufferResultsToOutParamsPass> {
220228
explicit BufferResultsToOutParamsPass(
221-
const bufferization::BufferResultsToOutParamsOptions &options)
229+
const bufferization::BufferResultsToOutParamsOpts &options)
222230
: options(options) {}
223231

224232
void runOnOperation() override {
233+
// Convert from pass options in tablegen to BufferResultsToOutParamsOpts.
234+
if (addResultAttribute)
235+
options.addResultAttribute = true;
236+
225237
if (failed(bufferization::promoteBufferResultsToOutParams(getOperation(),
226238
options)))
227239
return signalPassFailure();
228240
}
229241

230242
private:
231-
bufferization::BufferResultsToOutParamsOptions options;
243+
bufferization::BufferResultsToOutParamsOpts options;
232244
};
233245
} // namespace
234246

235247
std::unique_ptr<Pass> mlir::bufferization::createBufferResultsToOutParamsPass(
236-
const bufferization::BufferResultsToOutParamsOptions &options) {
248+
const bufferization::BufferResultsToOutParamsOpts &options) {
237249
return std::make_unique<BufferResultsToOutParamsPass>(options);
238250
}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
// RUN: mlir-opt -p 'builtin.module(buffer-results-to-out-params{add-result-attr})' -split-input-file -verify-diagnostics %s | FileCheck %s
2+
3+
// CHECK-LABEL: @basic({{.*}}: memref<f32> {bufferize.result})
4+
func.func @basic() -> (memref<f32>) {
5+
%0 = "test.source"() : () -> (memref<f32>)
6+
return %0 : memref<f32>
7+
}
8+
9+
// -----
10+
11+
// CHECK-LABEL: multiple_results
12+
// CHECK-SAME: memref<1xf32> {bufferize.result}
13+
// CHECK-SAME: memref<2xf32> {bufferize.result}
14+
func.func @multiple_results() -> (memref<1xf32>, memref<2xf32>) {
15+
%0, %1 = "test.source"() : () -> (memref<1xf32>, memref<2xf32>)
16+
return %0, %1 : memref<1xf32>, memref<2xf32>
17+
}

0 commit comments

Comments
 (0)