Skip to content

Commit 4671565

Browse files
committed
[MLIR][Bufferization] BufferResultsToOutParams: Add option to add attribute to output arguments
Adds a new pass option `add-result-attr` that will make the pass add the attribute `{bufferize.result}` to each argument that was converted from a result. To be able to test this, the pass option was added to the tablegen. To avoid collisions with the existing, manually defined option struct `BufferResultsToOutParamsOptions`, that one was renamed to `BufferResultsToOutParamsOpts`.
1 parent 5830d1a commit 4671565

File tree

4 files changed

+48
-11
lines changed

4 files changed

+48
-11
lines changed

mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ std::unique_ptr<Pass> createBufferLoopHoistingPass();
148148

149149
// Options struct for BufferResultsToOutParams pass.
150150
// Note: defined only here, not in tablegen.
151-
struct BufferResultsToOutParamsOptions {
151+
struct BufferResultsToOutParamsOpts {
152152
/// Memcpy function: Generate a memcpy between two memrefs.
153153
using MemCpyFn =
154154
std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
@@ -162,17 +162,21 @@ struct BufferResultsToOutParamsOptions {
162162
/// Memcpy function; used to create a copy between two memrefs.
163163
/// If this is empty, memref.copy is used.
164164
std::optional<MemCpyFn> memCpyFn;
165+
166+
/// If true, the pass adds a "bufferize.result" attribute to each output
167+
/// parameter.
168+
bool addResultAttribute = false;
165169
};
166170

167171
/// Creates a pass that converts memref function results to out-params.
168172
std::unique_ptr<Pass> createBufferResultsToOutParamsPass(
169-
const BufferResultsToOutParamsOptions &options = {});
173+
const BufferResultsToOutParamsOpts &options = {});
170174

171175
/// Replace buffers that are returned from a function with an out parameter.
172176
/// Also update all call sites.
173177
LogicalResult
174178
promoteBufferResultsToOutParams(ModuleOp module,
175-
const BufferResultsToOutParamsOptions &options);
179+
const BufferResultsToOutParamsOpts &options);
176180

177181
/// Creates a pass that drops memref function results that are equivalent to a
178182
/// function argument.

mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,11 @@ def BufferResultsToOutParams : Pass<"buffer-results-to-out-params", "ModuleOp">
316316
buffers for results need to be allocated in the caller. This currently only
317317
works for static shaped memrefs.
318318
}];
319+
let options = [
320+
Option<"addResultAttribute", "add-result-attr", "bool",
321+
/*default=*/"false",
322+
"Add the attribute 'bufferize.result' to all output parameters.">,
323+
];
319324
let constructor = "mlir::bufferization::createBufferResultsToOutParamsPass()";
320325
let dependentDialects = ["memref::MemRefDialect"];
321326
}

mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ namespace bufferization {
2121
} // namespace mlir
2222

2323
using namespace mlir;
24-
using MemCpyFn = bufferization::BufferResultsToOutParamsOptions::MemCpyFn;
24+
using MemCpyFn = bufferization::BufferResultsToOutParamsOpts::MemCpyFn;
2525

2626
/// Return `true` if the given MemRef type has a fully dynamic layout.
2727
static bool hasFullyDynamicLayoutMap(MemRefType type) {
@@ -47,7 +47,8 @@ static bool hasStaticIdentityLayout(MemRefType type) {
4747
// Any args appended to the entry block are added to `appendedEntryArgs`.
4848
static LogicalResult
4949
updateFuncOp(func::FuncOp func,
50-
SmallVectorImpl<BlockArgument> &appendedEntryArgs) {
50+
SmallVectorImpl<BlockArgument> &appendedEntryArgs,
51+
bool addResultAttribute) {
5152
auto functionType = func.getFunctionType();
5253

5354
// Collect information about the results will become appended arguments.
@@ -80,6 +81,10 @@ updateFuncOp(func::FuncOp func,
8081
for (int i = 0, e = erasedResultTypes.size(); i < e; ++i, ++erasedIndicesIt) {
8182
func.setArgAttrs(functionType.getNumInputs() + i,
8283
func.getResultAttrs(*erasedIndicesIt));
84+
if (addResultAttribute)
85+
func.setArgAttr(functionType.getNumInputs() + i,
86+
StringAttr::get(func.getContext(), "bufferize.result"),
87+
UnitAttr::get(func.getContext()));
8388
}
8489

8590
// Erase the results.
@@ -127,7 +132,7 @@ static LogicalResult updateReturnOps(func::FuncOp func,
127132
// temporary buffers for newly introduced out params.
128133
static LogicalResult
129134
updateCalls(ModuleOp module,
130-
const bufferization::BufferResultsToOutParamsOptions &options) {
135+
const bufferization::BufferResultsToOutParamsOpts &options) {
131136
bool didFail = false;
132137
SymbolTable symtab(module);
133138
module.walk([&](func::CallOp op) {
@@ -189,12 +194,13 @@ updateCalls(ModuleOp module,
189194

190195
LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
191196
ModuleOp module,
192-
const bufferization::BufferResultsToOutParamsOptions &options) {
197+
const bufferization::BufferResultsToOutParamsOpts &options) {
193198
for (auto func : module.getOps<func::FuncOp>()) {
194199
if (!options.filterFn(&func))
195200
continue;
196201
SmallVector<BlockArgument, 6> appendedEntryArgs;
197-
if (failed(updateFuncOp(func, appendedEntryArgs)))
202+
if (failed(
203+
updateFuncOp(func, appendedEntryArgs, options.addResultAttribute)))
198204
return failure();
199205
if (func.isExternal())
200206
continue;
@@ -218,21 +224,25 @@ struct BufferResultsToOutParamsPass
218224
: bufferization::impl::BufferResultsToOutParamsBase<
219225
BufferResultsToOutParamsPass> {
220226
explicit BufferResultsToOutParamsPass(
221-
const bufferization::BufferResultsToOutParamsOptions &options)
227+
const bufferization::BufferResultsToOutParamsOpts &options)
222228
: options(options) {}
223229

224230
void runOnOperation() override {
231+
// Convert from pass options in tablegen to BufferResultsToOutParamsOpts.
232+
if (addResultAttribute)
233+
options.addResultAttribute = true;
234+
225235
if (failed(bufferization::promoteBufferResultsToOutParams(getOperation(),
226236
options)))
227237
return signalPassFailure();
228238
}
229239

230240
private:
231-
bufferization::BufferResultsToOutParamsOptions options;
241+
bufferization::BufferResultsToOutParamsOpts options;
232242
};
233243
} // namespace
234244

235245
std::unique_ptr<Pass> mlir::bufferization::createBufferResultsToOutParamsPass(
236-
const bufferization::BufferResultsToOutParamsOptions &options) {
246+
const bufferization::BufferResultsToOutParamsOpts &options) {
237247
return std::make_unique<BufferResultsToOutParamsPass>(options);
238248
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
// RUN: mlir-opt -p 'builtin.module(buffer-results-to-out-params{add-result-attr})' -split-input-file -verify-diagnostics %s | FileCheck %s
2+
3+
// CHECK-LABEL: basic
4+
// CHECK-SAME: memref<f32> {bufferize.result})
5+
func.func @basic() -> (memref<f32>) {
6+
%0 = "test.source"() : () -> (memref<f32>)
7+
return %0 : memref<f32>
8+
}
9+
10+
// -----
11+
12+
// CHECK-LABEL: multiple_results
13+
// CHECK-SAME: memref<1xf32> {bufferize.result},
14+
// CHECK-SAME: memref<2xf32> {bufferize.result})
15+
func.func @multiple_results() -> (memref<1xf32>, memref<2xf32>) {
16+
%0, %1 = "test.source"() : () -> (memref<1xf32>, memref<2xf32>)
17+
return %0, %1 : memref<1xf32>, memref<2xf32>
18+
}

0 commit comments

Comments
 (0)