-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[MLIR][Bufferization] BufferResultsToOutParams: Add option to add attribute to output arguments #84320
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir Author: Matthias Gehre (mgehre-amd) ChangesAdds a new pass option To be able to test this, the pass option was added to the tablegen. To avoid collisions with the existing, manually defined option struct Full diff: https://github.com/llvm/llvm-project/pull/84320.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
index 809f03407258a8..a729bc99b987cd 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
@@ -148,7 +148,7 @@ std::unique_ptr<Pass> createBufferLoopHoistingPass();
// Options struct for BufferResultsToOutParams pass.
// Note: defined only here, not in tablegen.
-struct BufferResultsToOutParamsOptions {
+struct BufferResultsToOutParamsOpts {
/// Memcpy function: Generate a memcpy between two memrefs.
using MemCpyFn =
std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
@@ -162,17 +162,21 @@ struct BufferResultsToOutParamsOptions {
/// Memcpy function; used to create a copy between two memrefs.
/// If this is empty, memref.copy is used.
std::optional<MemCpyFn> memCpyFn;
+
+ /// If true, the pass adds a "bufferize.result" attribute to each output
+ /// parameter.
+ bool addResultAttribute = false;
};
/// Creates a pass that converts memref function results to out-params.
std::unique_ptr<Pass> createBufferResultsToOutParamsPass(
- const BufferResultsToOutParamsOptions &options = {});
+ const BufferResultsToOutParamsOpts &options = {});
/// Replace buffers that are returned from a function with an out parameter.
/// Also update all call sites.
LogicalResult
promoteBufferResultsToOutParams(ModuleOp module,
- const BufferResultsToOutParamsOptions &options);
+ const BufferResultsToOutParamsOpts &options);
/// Creates a pass that drops memref function results that are equivalent to a
/// function argument.
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
index e01f36b8daa18d..1c3cdec81a39e0 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
@@ -316,6 +316,11 @@ def BufferResultsToOutParams : Pass<"buffer-results-to-out-params", "ModuleOp">
buffers for results need to be allocated in the caller. This currently only
works for static shaped memrefs.
}];
+ let options = [
+ Option<"addResultAttribute", "add-result-attr", "bool",
+ /*default=*/"false",
+ "Add the attribute 'bufferize.result' to all output parameters.">,
+ ];
let constructor = "mlir::bufferization::createBufferResultsToOutParamsPass()";
let dependentDialects = ["memref::MemRefDialect"];
}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
index 930f035339c1d3..5ab347066c90cb 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
@@ -21,7 +21,7 @@ namespace bufferization {
} // namespace mlir
using namespace mlir;
-using MemCpyFn = bufferization::BufferResultsToOutParamsOptions::MemCpyFn;
+using MemCpyFn = bufferization::BufferResultsToOutParamsOpts::MemCpyFn;
/// Return `true` if the given MemRef type has a fully dynamic layout.
static bool hasFullyDynamicLayoutMap(MemRefType type) {
@@ -47,7 +47,8 @@ static bool hasStaticIdentityLayout(MemRefType type) {
// Any args appended to the entry block are added to `appendedEntryArgs`.
static LogicalResult
updateFuncOp(func::FuncOp func,
- SmallVectorImpl<BlockArgument> &appendedEntryArgs) {
+ SmallVectorImpl<BlockArgument> &appendedEntryArgs,
+ bool addResultAttribute) {
auto functionType = func.getFunctionType();
// Collect information about the results will become appended arguments.
@@ -80,6 +81,10 @@ updateFuncOp(func::FuncOp func,
for (int i = 0, e = erasedResultTypes.size(); i < e; ++i, ++erasedIndicesIt) {
func.setArgAttrs(functionType.getNumInputs() + i,
func.getResultAttrs(*erasedIndicesIt));
+ if (addResultAttribute)
+ func.setArgAttr(functionType.getNumInputs() + i,
+ StringAttr::get(func.getContext(), "bufferize.result"),
+ UnitAttr::get(func.getContext()));
}
// Erase the results.
@@ -127,7 +132,7 @@ static LogicalResult updateReturnOps(func::FuncOp func,
// temporary buffers for newly introduced out params.
static LogicalResult
updateCalls(ModuleOp module,
- const bufferization::BufferResultsToOutParamsOptions &options) {
+ const bufferization::BufferResultsToOutParamsOpts &options) {
bool didFail = false;
SymbolTable symtab(module);
module.walk([&](func::CallOp op) {
@@ -189,12 +194,13 @@ updateCalls(ModuleOp module,
LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
ModuleOp module,
- const bufferization::BufferResultsToOutParamsOptions &options) {
+ const bufferization::BufferResultsToOutParamsOpts &options) {
for (auto func : module.getOps<func::FuncOp>()) {
if (!options.filterFn(&func))
continue;
SmallVector<BlockArgument, 6> appendedEntryArgs;
- if (failed(updateFuncOp(func, appendedEntryArgs)))
+ if (failed(
+ updateFuncOp(func, appendedEntryArgs, options.addResultAttribute)))
return failure();
if (func.isExternal())
continue;
@@ -218,21 +224,25 @@ struct BufferResultsToOutParamsPass
: bufferization::impl::BufferResultsToOutParamsBase<
BufferResultsToOutParamsPass> {
explicit BufferResultsToOutParamsPass(
- const bufferization::BufferResultsToOutParamsOptions &options)
+ const bufferization::BufferResultsToOutParamsOpts &options)
: options(options) {}
void runOnOperation() override {
+ // Convert from pass options in tablegen to BufferResultsToOutParamsOpts.
+ if (addResultAttribute)
+ options.addResultAttribute = true;
+
if (failed(bufferization::promoteBufferResultsToOutParams(getOperation(),
options)))
return signalPassFailure();
}
private:
- bufferization::BufferResultsToOutParamsOptions options;
+ bufferization::BufferResultsToOutParamsOpts options;
};
} // namespace
std::unique_ptr<Pass> mlir::bufferization::createBufferResultsToOutParamsPass(
- const bufferization::BufferResultsToOutParamsOptions &options) {
+ const bufferization::BufferResultsToOutParamsOpts &options) {
return std::make_unique<BufferResultsToOutParamsPass>(options);
}
diff --git a/mlir/test/Transforms/buffer-results-to-out-params-add-result-attr.mlir b/mlir/test/Transforms/buffer-results-to-out-params-add-result-attr.mlir
new file mode 100644
index 00000000000000..48d5d2372b869e
--- /dev/null
+++ b/mlir/test/Transforms/buffer-results-to-out-params-add-result-attr.mlir
@@ -0,0 +1,18 @@
+// RUN: mlir-opt -p 'builtin.module(buffer-results-to-out-params{add-result-attr})' -split-input-file -verify-diagnostics %s | FileCheck %s
+
+// CHECK-LABEL: basic
+// CHECK-SAME: memref<f32> {bufferize.result})
+func.func @basic() -> (memref<f32>) {
+ %0 = "test.source"() : () -> (memref<f32>)
+ return %0 : memref<f32>
+}
+
+// -----
+
+// CHECK-LABEL: multiple_results
+// CHECK-SAME: memref<1xf32> {bufferize.result},
+// CHECK-SAME: memref<2xf32> {bufferize.result})
+func.func @multiple_results() -> (memref<1xf32>, memref<2xf32>) {
+ %0, %1 = "test.source"() : () -> (memref<1xf32>, memref<2xf32>)
+ return %0, %1 : memref<1xf32>, memref<2xf32>
+}
|
@llvm/pr-subscribers-mlir-bufferization Author: Matthias Gehre (mgehre-amd) ChangesAdds a new pass option To be able to test this, the pass option was added to the tablegen. To avoid collisions with the existing, manually defined option struct Full diff: https://github.com/llvm/llvm-project/pull/84320.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
index 809f03407258a8..a729bc99b987cd 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
@@ -148,7 +148,7 @@ std::unique_ptr<Pass> createBufferLoopHoistingPass();
// Options struct for BufferResultsToOutParams pass.
// Note: defined only here, not in tablegen.
-struct BufferResultsToOutParamsOptions {
+struct BufferResultsToOutParamsOpts {
/// Memcpy function: Generate a memcpy between two memrefs.
using MemCpyFn =
std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
@@ -162,17 +162,21 @@ struct BufferResultsToOutParamsOptions {
/// Memcpy function; used to create a copy between two memrefs.
/// If this is empty, memref.copy is used.
std::optional<MemCpyFn> memCpyFn;
+
+ /// If true, the pass adds a "bufferize.result" attribute to each output
+ /// parameter.
+ bool addResultAttribute = false;
};
/// Creates a pass that converts memref function results to out-params.
std::unique_ptr<Pass> createBufferResultsToOutParamsPass(
- const BufferResultsToOutParamsOptions &options = {});
+ const BufferResultsToOutParamsOpts &options = {});
/// Replace buffers that are returned from a function with an out parameter.
/// Also update all call sites.
LogicalResult
promoteBufferResultsToOutParams(ModuleOp module,
- const BufferResultsToOutParamsOptions &options);
+ const BufferResultsToOutParamsOpts &options);
/// Creates a pass that drops memref function results that are equivalent to a
/// function argument.
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
index e01f36b8daa18d..1c3cdec81a39e0 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
@@ -316,6 +316,11 @@ def BufferResultsToOutParams : Pass<"buffer-results-to-out-params", "ModuleOp">
buffers for results need to be allocated in the caller. This currently only
works for static shaped memrefs.
}];
+ let options = [
+ Option<"addResultAttribute", "add-result-attr", "bool",
+ /*default=*/"false",
+ "Add the attribute 'bufferize.result' to all output parameters.">,
+ ];
let constructor = "mlir::bufferization::createBufferResultsToOutParamsPass()";
let dependentDialects = ["memref::MemRefDialect"];
}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
index 930f035339c1d3..5ab347066c90cb 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
@@ -21,7 +21,7 @@ namespace bufferization {
} // namespace mlir
using namespace mlir;
-using MemCpyFn = bufferization::BufferResultsToOutParamsOptions::MemCpyFn;
+using MemCpyFn = bufferization::BufferResultsToOutParamsOpts::MemCpyFn;
/// Return `true` if the given MemRef type has a fully dynamic layout.
static bool hasFullyDynamicLayoutMap(MemRefType type) {
@@ -47,7 +47,8 @@ static bool hasStaticIdentityLayout(MemRefType type) {
// Any args appended to the entry block are added to `appendedEntryArgs`.
static LogicalResult
updateFuncOp(func::FuncOp func,
- SmallVectorImpl<BlockArgument> &appendedEntryArgs) {
+ SmallVectorImpl<BlockArgument> &appendedEntryArgs,
+ bool addResultAttribute) {
auto functionType = func.getFunctionType();
// Collect information about the results will become appended arguments.
@@ -80,6 +81,10 @@ updateFuncOp(func::FuncOp func,
for (int i = 0, e = erasedResultTypes.size(); i < e; ++i, ++erasedIndicesIt) {
func.setArgAttrs(functionType.getNumInputs() + i,
func.getResultAttrs(*erasedIndicesIt));
+ if (addResultAttribute)
+ func.setArgAttr(functionType.getNumInputs() + i,
+ StringAttr::get(func.getContext(), "bufferize.result"),
+ UnitAttr::get(func.getContext()));
}
// Erase the results.
@@ -127,7 +132,7 @@ static LogicalResult updateReturnOps(func::FuncOp func,
// temporary buffers for newly introduced out params.
static LogicalResult
updateCalls(ModuleOp module,
- const bufferization::BufferResultsToOutParamsOptions &options) {
+ const bufferization::BufferResultsToOutParamsOpts &options) {
bool didFail = false;
SymbolTable symtab(module);
module.walk([&](func::CallOp op) {
@@ -189,12 +194,13 @@ updateCalls(ModuleOp module,
LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
ModuleOp module,
- const bufferization::BufferResultsToOutParamsOptions &options) {
+ const bufferization::BufferResultsToOutParamsOpts &options) {
for (auto func : module.getOps<func::FuncOp>()) {
if (!options.filterFn(&func))
continue;
SmallVector<BlockArgument, 6> appendedEntryArgs;
- if (failed(updateFuncOp(func, appendedEntryArgs)))
+ if (failed(
+ updateFuncOp(func, appendedEntryArgs, options.addResultAttribute)))
return failure();
if (func.isExternal())
continue;
@@ -218,21 +224,25 @@ struct BufferResultsToOutParamsPass
: bufferization::impl::BufferResultsToOutParamsBase<
BufferResultsToOutParamsPass> {
explicit BufferResultsToOutParamsPass(
- const bufferization::BufferResultsToOutParamsOptions &options)
+ const bufferization::BufferResultsToOutParamsOpts &options)
: options(options) {}
void runOnOperation() override {
+ // Convert from pass options in tablegen to BufferResultsToOutParamsOpts.
+ if (addResultAttribute)
+ options.addResultAttribute = true;
+
if (failed(bufferization::promoteBufferResultsToOutParams(getOperation(),
options)))
return signalPassFailure();
}
private:
- bufferization::BufferResultsToOutParamsOptions options;
+ bufferization::BufferResultsToOutParamsOpts options;
};
} // namespace
std::unique_ptr<Pass> mlir::bufferization::createBufferResultsToOutParamsPass(
- const bufferization::BufferResultsToOutParamsOptions &options) {
+ const bufferization::BufferResultsToOutParamsOpts &options) {
return std::make_unique<BufferResultsToOutParamsPass>(options);
}
diff --git a/mlir/test/Transforms/buffer-results-to-out-params-add-result-attr.mlir b/mlir/test/Transforms/buffer-results-to-out-params-add-result-attr.mlir
new file mode 100644
index 00000000000000..48d5d2372b869e
--- /dev/null
+++ b/mlir/test/Transforms/buffer-results-to-out-params-add-result-attr.mlir
@@ -0,0 +1,18 @@
+// RUN: mlir-opt -p 'builtin.module(buffer-results-to-out-params{add-result-attr})' -split-input-file -verify-diagnostics %s | FileCheck %s
+
+// CHECK-LABEL: basic
+// CHECK-SAME: memref<f32> {bufferize.result})
+func.func @basic() -> (memref<f32>) {
+ %0 = "test.source"() : () -> (memref<f32>)
+ return %0 : memref<f32>
+}
+
+// -----
+
+// CHECK-LABEL: multiple_results
+// CHECK-SAME: memref<1xf32> {bufferize.result},
+// CHECK-SAME: memref<2xf32> {bufferize.result})
+func.func @multiple_results() -> (memref<1xf32>, memref<2xf32>) {
+ %0, %1 = "test.source"() : () -> (memref<1xf32>, memref<2xf32>)
+ return %0, %1 : memref<1xf32>, memref<2xf32>
+}
|
|
b0c081b
to
4671565
Compare
mlir/test/Transforms/buffer-results-to-out-params-add-result-attr.mlir
Outdated
Show resolved
Hide resolved
…ribute to output arguments Adds a new pass option `add-result-attr` that will make the pass add the attribute `{bufferize.result}` to each argument that was converted from a result. To be able to test this, the pass option was added to the tablegen. To avoid collisions with the existing, manually defined option struct `BufferResultsToOutParamsOptions`, that one was renamed to `BufferResultsToOutParamsOpts`.
4671565
to
be1b0a1
Compare
Adds a new pass option
add-result-attr
that will make the pass add the attribute{bufferize.result}
to each argument that was converted from a result.This is important e.g. when later using the python bindings / execution engine to understand which arguments are actually results.
To be able to test this, the pass option was added to the tablegen. To avoid collisions with the existing, manually defined option struct
BufferResultsToOutParamsOptions
, that one was renamed toBufferResultsToOutParamsOpts
.