Skip to content

[MLIR][Bufferization] BufferResultsToOutParams: Add option to add attribute to output arguments #84320

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ std::unique_ptr<Pass> createBufferLoopHoistingPass();

// Options struct for BufferResultsToOutParams pass.
// Note: defined only here, not in tablegen.
struct BufferResultsToOutParamsOptions {
struct BufferResultsToOutParamsOpts {
/// Memcpy function: Generate a memcpy between two memrefs.
using MemCpyFn =
std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
Expand All @@ -162,17 +162,21 @@ struct BufferResultsToOutParamsOptions {
/// Memcpy function; used to create a copy between two memrefs.
/// If this is empty, memref.copy is used.
std::optional<MemCpyFn> memCpyFn;

/// If true, the pass adds a "bufferize.result" attribute to each output
/// parameter.
bool addResultAttribute = false;
};

/// Creates a pass that converts memref function results to out-params.
std::unique_ptr<Pass> createBufferResultsToOutParamsPass(
const BufferResultsToOutParamsOptions &options = {});
const BufferResultsToOutParamsOpts &options = {});

/// Replace buffers that are returned from a function with an out parameter.
/// Also update all call sites.
LogicalResult
promoteBufferResultsToOutParams(ModuleOp module,
const BufferResultsToOutParamsOptions &options);
const BufferResultsToOutParamsOpts &options);

/// Creates a pass that drops memref function results that are equivalent to a
/// function argument.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,11 @@ def BufferResultsToOutParams : Pass<"buffer-results-to-out-params", "ModuleOp">
buffers for results need to be allocated in the caller. This currently only
works for static shaped memrefs.
}];
let options = [
Option<"addResultAttribute", "add-result-attr", "bool",
/*default=*/"false",
"Add the attribute 'bufferize.result' to all output parameters.">,
];
let constructor = "mlir::bufferization::createBufferResultsToOutParamsPass()";
let dependentDialects = ["memref::MemRefDialect"];
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ namespace bufferization {
} // namespace mlir

using namespace mlir;
using MemCpyFn = bufferization::BufferResultsToOutParamsOptions::MemCpyFn;
using MemCpyFn = bufferization::BufferResultsToOutParamsOpts::MemCpyFn;

/// Return `true` if the given MemRef type has a fully dynamic layout.
static bool hasFullyDynamicLayoutMap(MemRefType type) {
Expand All @@ -45,9 +45,12 @@ static bool hasStaticIdentityLayout(MemRefType type) {
// Updates the func op and entry block.
//
// Any args appended to the entry block are added to `appendedEntryArgs`.
// If `addResultAttribute` is true, adds the unit attribute `bufferize.result`
// to each newly created function argument.
static LogicalResult
updateFuncOp(func::FuncOp func,
SmallVectorImpl<BlockArgument> &appendedEntryArgs) {
SmallVectorImpl<BlockArgument> &appendedEntryArgs,
bool addResultAttribute) {
auto functionType = func.getFunctionType();

// Collect information about the results will become appended arguments.
Expand Down Expand Up @@ -80,6 +83,10 @@ updateFuncOp(func::FuncOp func,
for (int i = 0, e = erasedResultTypes.size(); i < e; ++i, ++erasedIndicesIt) {
func.setArgAttrs(functionType.getNumInputs() + i,
func.getResultAttrs(*erasedIndicesIt));
if (addResultAttribute)
func.setArgAttr(functionType.getNumInputs() + i,
StringAttr::get(func.getContext(), "bufferize.result"),
UnitAttr::get(func.getContext()));
}

// Erase the results.
Expand Down Expand Up @@ -127,7 +134,7 @@ static LogicalResult updateReturnOps(func::FuncOp func,
// temporary buffers for newly introduced out params.
static LogicalResult
updateCalls(ModuleOp module,
const bufferization::BufferResultsToOutParamsOptions &options) {
const bufferization::BufferResultsToOutParamsOpts &options) {
bool didFail = false;
SymbolTable symtab(module);
module.walk([&](func::CallOp op) {
Expand Down Expand Up @@ -189,12 +196,13 @@ updateCalls(ModuleOp module,

LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
ModuleOp module,
const bufferization::BufferResultsToOutParamsOptions &options) {
const bufferization::BufferResultsToOutParamsOpts &options) {
for (auto func : module.getOps<func::FuncOp>()) {
if (!options.filterFn(&func))
continue;
SmallVector<BlockArgument, 6> appendedEntryArgs;
if (failed(updateFuncOp(func, appendedEntryArgs)))
if (failed(
updateFuncOp(func, appendedEntryArgs, options.addResultAttribute)))
return failure();
if (func.isExternal())
continue;
Expand All @@ -218,21 +226,25 @@ struct BufferResultsToOutParamsPass
: bufferization::impl::BufferResultsToOutParamsBase<
BufferResultsToOutParamsPass> {
explicit BufferResultsToOutParamsPass(
const bufferization::BufferResultsToOutParamsOptions &options)
const bufferization::BufferResultsToOutParamsOpts &options)
: options(options) {}

void runOnOperation() override {
// Convert from pass options in tablegen to BufferResultsToOutParamsOpts.
if (addResultAttribute)
options.addResultAttribute = true;

if (failed(bufferization::promoteBufferResultsToOutParams(getOperation(),
options)))
return signalPassFailure();
}

private:
bufferization::BufferResultsToOutParamsOptions options;
bufferization::BufferResultsToOutParamsOpts options;
};
} // namespace

std::unique_ptr<Pass> mlir::bufferization::createBufferResultsToOutParamsPass(
const bufferization::BufferResultsToOutParamsOptions &options) {
const bufferization::BufferResultsToOutParamsOpts &options) {
return std::make_unique<BufferResultsToOutParamsPass>(options);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// RUN: mlir-opt -p 'builtin.module(buffer-results-to-out-params{add-result-attr})' -split-input-file -verify-diagnostics %s | FileCheck %s

// CHECK-LABEL: @basic({{.*}}: memref<f32> {bufferize.result})
func.func @basic() -> (memref<f32>) {
%0 = "test.source"() : () -> (memref<f32>)
return %0 : memref<f32>
}

// -----

// CHECK-LABEL: multiple_results
// CHECK-SAME: memref<1xf32> {bufferize.result}
// CHECK-SAME: memref<2xf32> {bufferize.result}
func.func @multiple_results() -> (memref<1xf32>, memref<2xf32>) {
%0, %1 = "test.source"() : () -> (memref<1xf32>, memref<2xf32>)
return %0, %1 : memref<1xf32>, memref<2xf32>
}