Skip to content

[MLIR][Bufferization] BufferResultsToOutParams: Add option to add attribute to output arguments #84320

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 14, 2024

Conversation

mgehre-amd
Copy link
Contributor

Adds a new pass option add-result-attr that will make the pass add the attribute {bufferize.result} to each argument that was converted from a result.
This is important e.g. when later using the python bindings / execution engine to understand which arguments are actually results.

To be able to test this, the pass option was added to the tablegen. To avoid collisions with the existing, manually defined option struct BufferResultsToOutParamsOptions, that one was renamed to BufferResultsToOutParamsOpts.

@llvmbot llvmbot added mlir mlir:bufferization Bufferization infrastructure labels Mar 7, 2024
@llvmbot
Copy link
Member

llvmbot commented Mar 7, 2024

@llvm/pr-subscribers-mlir

Author: Matthias Gehre (mgehre-amd)

Changes

Adds a new pass option add-result-attr that will make the pass add the attribute {bufferize.result} to each argument that was converted from a result.
This is important e.g. when later using the python bindings / execution engine to understand which arguments are actually results.

To be able to test this, the pass option was added to the tablegen. To avoid collisions with the existing, manually defined option struct BufferResultsToOutParamsOptions, that one was renamed to BufferResultsToOutParamsOpts.


Full diff: https://github.com/llvm/llvm-project/pull/84320.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h (+7-3)
  • (modified) mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td (+5)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp (+18-8)
  • (added) mlir/test/Transforms/buffer-results-to-out-params-add-result-attr.mlir (+18)
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
index 809f03407258a8..a729bc99b987cd 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
@@ -148,7 +148,7 @@ std::unique_ptr<Pass> createBufferLoopHoistingPass();
 
 // Options struct for BufferResultsToOutParams pass.
 // Note: defined only here, not in tablegen.
-struct BufferResultsToOutParamsOptions {
+struct BufferResultsToOutParamsOpts {
   /// Memcpy function: Generate a memcpy between two memrefs.
   using MemCpyFn =
       std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
@@ -162,17 +162,21 @@ struct BufferResultsToOutParamsOptions {
   /// Memcpy function; used to create a copy between two memrefs.
   /// If this is empty, memref.copy is used.
   std::optional<MemCpyFn> memCpyFn;
+
+  /// If true, the pass adds a "bufferize.result" attribute to each output
+  /// parameter.
+  bool addResultAttribute = false;
 };
 
 /// Creates a pass that converts memref function results to out-params.
 std::unique_ptr<Pass> createBufferResultsToOutParamsPass(
-    const BufferResultsToOutParamsOptions &options = {});
+    const BufferResultsToOutParamsOpts &options = {});
 
 /// Replace buffers that are returned from a function with an out parameter.
 /// Also update all call sites.
 LogicalResult
 promoteBufferResultsToOutParams(ModuleOp module,
-                                const BufferResultsToOutParamsOptions &options);
+                                const BufferResultsToOutParamsOpts &options);
 
 /// Creates a pass that drops memref function results that are equivalent to a
 /// function argument.
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
index e01f36b8daa18d..1c3cdec81a39e0 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
@@ -316,6 +316,11 @@ def BufferResultsToOutParams : Pass<"buffer-results-to-out-params", "ModuleOp">
     buffers for results need to be allocated in the caller. This currently only
     works for static shaped memrefs.
   }];
+  let options = [
+    Option<"addResultAttribute", "add-result-attr", "bool",
+       /*default=*/"false",
+       "Add the attribute 'bufferize.result' to all output parameters.">,
+  ];
   let constructor = "mlir::bufferization::createBufferResultsToOutParamsPass()";
   let dependentDialects = ["memref::MemRefDialect"];
 }
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
index 930f035339c1d3..5ab347066c90cb 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
@@ -21,7 +21,7 @@ namespace bufferization {
 } // namespace mlir
 
 using namespace mlir;
-using MemCpyFn = bufferization::BufferResultsToOutParamsOptions::MemCpyFn;
+using MemCpyFn = bufferization::BufferResultsToOutParamsOpts::MemCpyFn;
 
 /// Return `true` if the given MemRef type has a fully dynamic layout.
 static bool hasFullyDynamicLayoutMap(MemRefType type) {
@@ -47,7 +47,8 @@ static bool hasStaticIdentityLayout(MemRefType type) {
 // Any args appended to the entry block are added to `appendedEntryArgs`.
 static LogicalResult
 updateFuncOp(func::FuncOp func,
-             SmallVectorImpl<BlockArgument> &appendedEntryArgs) {
+             SmallVectorImpl<BlockArgument> &appendedEntryArgs,
+             bool addResultAttribute) {
   auto functionType = func.getFunctionType();
 
   // Collect information about the results will become appended arguments.
@@ -80,6 +81,10 @@ updateFuncOp(func::FuncOp func,
   for (int i = 0, e = erasedResultTypes.size(); i < e; ++i, ++erasedIndicesIt) {
     func.setArgAttrs(functionType.getNumInputs() + i,
                      func.getResultAttrs(*erasedIndicesIt));
+    if (addResultAttribute)
+      func.setArgAttr(functionType.getNumInputs() + i,
+                      StringAttr::get(func.getContext(), "bufferize.result"),
+                      UnitAttr::get(func.getContext()));
   }
 
   // Erase the results.
@@ -127,7 +132,7 @@ static LogicalResult updateReturnOps(func::FuncOp func,
 // temporary buffers for newly introduced out params.
 static LogicalResult
 updateCalls(ModuleOp module,
-            const bufferization::BufferResultsToOutParamsOptions &options) {
+            const bufferization::BufferResultsToOutParamsOpts &options) {
   bool didFail = false;
   SymbolTable symtab(module);
   module.walk([&](func::CallOp op) {
@@ -189,12 +194,13 @@ updateCalls(ModuleOp module,
 
 LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
     ModuleOp module,
-    const bufferization::BufferResultsToOutParamsOptions &options) {
+    const bufferization::BufferResultsToOutParamsOpts &options) {
   for (auto func : module.getOps<func::FuncOp>()) {
     if (!options.filterFn(&func))
       continue;
     SmallVector<BlockArgument, 6> appendedEntryArgs;
-    if (failed(updateFuncOp(func, appendedEntryArgs)))
+    if (failed(
+            updateFuncOp(func, appendedEntryArgs, options.addResultAttribute)))
       return failure();
     if (func.isExternal())
       continue;
@@ -218,21 +224,25 @@ struct BufferResultsToOutParamsPass
     : bufferization::impl::BufferResultsToOutParamsBase<
           BufferResultsToOutParamsPass> {
   explicit BufferResultsToOutParamsPass(
-      const bufferization::BufferResultsToOutParamsOptions &options)
+      const bufferization::BufferResultsToOutParamsOpts &options)
       : options(options) {}
 
   void runOnOperation() override {
+    // Convert from pass options in tablegen to BufferResultsToOutParamsOpts.
+    if (addResultAttribute)
+      options.addResultAttribute = true;
+
     if (failed(bufferization::promoteBufferResultsToOutParams(getOperation(),
                                                               options)))
       return signalPassFailure();
   }
 
 private:
-  bufferization::BufferResultsToOutParamsOptions options;
+  bufferization::BufferResultsToOutParamsOpts options;
 };
 } // namespace
 
 std::unique_ptr<Pass> mlir::bufferization::createBufferResultsToOutParamsPass(
-    const bufferization::BufferResultsToOutParamsOptions &options) {
+    const bufferization::BufferResultsToOutParamsOpts &options) {
   return std::make_unique<BufferResultsToOutParamsPass>(options);
 }
diff --git a/mlir/test/Transforms/buffer-results-to-out-params-add-result-attr.mlir b/mlir/test/Transforms/buffer-results-to-out-params-add-result-attr.mlir
new file mode 100644
index 00000000000000..48d5d2372b869e
--- /dev/null
+++ b/mlir/test/Transforms/buffer-results-to-out-params-add-result-attr.mlir
@@ -0,0 +1,18 @@
+// RUN: mlir-opt -p 'builtin.module(buffer-results-to-out-params{add-result-attr})' -split-input-file -verify-diagnostics %s | FileCheck %s
+
+// CHECK-LABEL: basic
+// CHECK-SAME:  memref<f32> {bufferize.result})
+func.func @basic() -> (memref<f32>) {
+  %0 = "test.source"() : () -> (memref<f32>)
+  return %0 : memref<f32>
+}
+
+// -----
+
+// CHECK-LABEL: multiple_results
+// CHECK-SAME:  memref<1xf32> {bufferize.result},
+// CHECK-SAME:  memref<2xf32> {bufferize.result})
+func.func @multiple_results() -> (memref<1xf32>, memref<2xf32>) {
+  %0, %1 = "test.source"() : () -> (memref<1xf32>, memref<2xf32>)
+  return %0, %1 : memref<1xf32>, memref<2xf32>
+}

@llvmbot
Copy link
Member

llvmbot commented Mar 7, 2024

@llvm/pr-subscribers-mlir-bufferization

Author: Matthias Gehre (mgehre-amd)

Changes

Adds a new pass option add-result-attr that will make the pass add the attribute {bufferize.result} to each argument that was converted from a result.
This is important e.g. when later using the python bindings / execution engine to understand which arguments are actually results.

To be able to test this, the pass option was added to the tablegen. To avoid collisions with the existing, manually defined option struct BufferResultsToOutParamsOptions, that one was renamed to BufferResultsToOutParamsOpts.


Full diff: https://github.com/llvm/llvm-project/pull/84320.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h (+7-3)
  • (modified) mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td (+5)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp (+18-8)
  • (added) mlir/test/Transforms/buffer-results-to-out-params-add-result-attr.mlir (+18)
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
index 809f03407258a8..a729bc99b987cd 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
@@ -148,7 +148,7 @@ std::unique_ptr<Pass> createBufferLoopHoistingPass();
 
 // Options struct for BufferResultsToOutParams pass.
 // Note: defined only here, not in tablegen.
-struct BufferResultsToOutParamsOptions {
+struct BufferResultsToOutParamsOpts {
   /// Memcpy function: Generate a memcpy between two memrefs.
   using MemCpyFn =
       std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
@@ -162,17 +162,21 @@ struct BufferResultsToOutParamsOptions {
   /// Memcpy function; used to create a copy between two memrefs.
   /// If this is empty, memref.copy is used.
   std::optional<MemCpyFn> memCpyFn;
+
+  /// If true, the pass adds a "bufferize.result" attribute to each output
+  /// parameter.
+  bool addResultAttribute = false;
 };
 
 /// Creates a pass that converts memref function results to out-params.
 std::unique_ptr<Pass> createBufferResultsToOutParamsPass(
-    const BufferResultsToOutParamsOptions &options = {});
+    const BufferResultsToOutParamsOpts &options = {});
 
 /// Replace buffers that are returned from a function with an out parameter.
 /// Also update all call sites.
 LogicalResult
 promoteBufferResultsToOutParams(ModuleOp module,
-                                const BufferResultsToOutParamsOptions &options);
+                                const BufferResultsToOutParamsOpts &options);
 
 /// Creates a pass that drops memref function results that are equivalent to a
 /// function argument.
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
index e01f36b8daa18d..1c3cdec81a39e0 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
@@ -316,6 +316,11 @@ def BufferResultsToOutParams : Pass<"buffer-results-to-out-params", "ModuleOp">
     buffers for results need to be allocated in the caller. This currently only
     works for static shaped memrefs.
   }];
+  let options = [
+    Option<"addResultAttribute", "add-result-attr", "bool",
+       /*default=*/"false",
+       "Add the attribute 'bufferize.result' to all output parameters.">,
+  ];
   let constructor = "mlir::bufferization::createBufferResultsToOutParamsPass()";
   let dependentDialects = ["memref::MemRefDialect"];
 }
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
index 930f035339c1d3..5ab347066c90cb 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
@@ -21,7 +21,7 @@ namespace bufferization {
 } // namespace mlir
 
 using namespace mlir;
-using MemCpyFn = bufferization::BufferResultsToOutParamsOptions::MemCpyFn;
+using MemCpyFn = bufferization::BufferResultsToOutParamsOpts::MemCpyFn;
 
 /// Return `true` if the given MemRef type has a fully dynamic layout.
 static bool hasFullyDynamicLayoutMap(MemRefType type) {
@@ -47,7 +47,8 @@ static bool hasStaticIdentityLayout(MemRefType type) {
 // Any args appended to the entry block are added to `appendedEntryArgs`.
 static LogicalResult
 updateFuncOp(func::FuncOp func,
-             SmallVectorImpl<BlockArgument> &appendedEntryArgs) {
+             SmallVectorImpl<BlockArgument> &appendedEntryArgs,
+             bool addResultAttribute) {
   auto functionType = func.getFunctionType();
 
   // Collect information about the results will become appended arguments.
@@ -80,6 +81,10 @@ updateFuncOp(func::FuncOp func,
   for (int i = 0, e = erasedResultTypes.size(); i < e; ++i, ++erasedIndicesIt) {
     func.setArgAttrs(functionType.getNumInputs() + i,
                      func.getResultAttrs(*erasedIndicesIt));
+    if (addResultAttribute)
+      func.setArgAttr(functionType.getNumInputs() + i,
+                      StringAttr::get(func.getContext(), "bufferize.result"),
+                      UnitAttr::get(func.getContext()));
   }
 
   // Erase the results.
@@ -127,7 +132,7 @@ static LogicalResult updateReturnOps(func::FuncOp func,
 // temporary buffers for newly introduced out params.
 static LogicalResult
 updateCalls(ModuleOp module,
-            const bufferization::BufferResultsToOutParamsOptions &options) {
+            const bufferization::BufferResultsToOutParamsOpts &options) {
   bool didFail = false;
   SymbolTable symtab(module);
   module.walk([&](func::CallOp op) {
@@ -189,12 +194,13 @@ updateCalls(ModuleOp module,
 
 LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
     ModuleOp module,
-    const bufferization::BufferResultsToOutParamsOptions &options) {
+    const bufferization::BufferResultsToOutParamsOpts &options) {
   for (auto func : module.getOps<func::FuncOp>()) {
     if (!options.filterFn(&func))
       continue;
     SmallVector<BlockArgument, 6> appendedEntryArgs;
-    if (failed(updateFuncOp(func, appendedEntryArgs)))
+    if (failed(
+            updateFuncOp(func, appendedEntryArgs, options.addResultAttribute)))
       return failure();
     if (func.isExternal())
       continue;
@@ -218,21 +224,25 @@ struct BufferResultsToOutParamsPass
     : bufferization::impl::BufferResultsToOutParamsBase<
           BufferResultsToOutParamsPass> {
   explicit BufferResultsToOutParamsPass(
-      const bufferization::BufferResultsToOutParamsOptions &options)
+      const bufferization::BufferResultsToOutParamsOpts &options)
       : options(options) {}
 
   void runOnOperation() override {
+    // Convert from pass options in tablegen to BufferResultsToOutParamsOpts.
+    if (addResultAttribute)
+      options.addResultAttribute = true;
+
     if (failed(bufferization::promoteBufferResultsToOutParams(getOperation(),
                                                               options)))
       return signalPassFailure();
   }
 
 private:
-  bufferization::BufferResultsToOutParamsOptions options;
+  bufferization::BufferResultsToOutParamsOpts options;
 };
 } // namespace
 
 std::unique_ptr<Pass> mlir::bufferization::createBufferResultsToOutParamsPass(
-    const bufferization::BufferResultsToOutParamsOptions &options) {
+    const bufferization::BufferResultsToOutParamsOpts &options) {
   return std::make_unique<BufferResultsToOutParamsPass>(options);
 }
diff --git a/mlir/test/Transforms/buffer-results-to-out-params-add-result-attr.mlir b/mlir/test/Transforms/buffer-results-to-out-params-add-result-attr.mlir
new file mode 100644
index 00000000000000..48d5d2372b869e
--- /dev/null
+++ b/mlir/test/Transforms/buffer-results-to-out-params-add-result-attr.mlir
@@ -0,0 +1,18 @@
+// RUN: mlir-opt -p 'builtin.module(buffer-results-to-out-params{add-result-attr})' -split-input-file -verify-diagnostics %s | FileCheck %s
+
+// CHECK-LABEL: basic
+// CHECK-SAME:  memref<f32> {bufferize.result})
+func.func @basic() -> (memref<f32>) {
+  %0 = "test.source"() : () -> (memref<f32>)
+  return %0 : memref<f32>
+}
+
+// -----
+
+// CHECK-LABEL: multiple_results
+// CHECK-SAME:  memref<1xf32> {bufferize.result},
+// CHECK-SAME:  memref<2xf32> {bufferize.result})
+func.func @multiple_results() -> (memref<1xf32>, memref<2xf32>) {
+  %0, %1 = "test.source"() : () -> (memref<1xf32>, memref<2xf32>)
+  return %0, %1 : memref<1xf32>, memref<2xf32>
+}

Copy link

github-actions bot commented Mar 7, 2024

⚠️ We detected that you are using a GitHub private e-mail address to contribute to the repo.
Please turn off Keep my email addresses private setting in your account.
See LLVM Discourse for more information.

@mgehre-amd mgehre-amd force-pushed the matthias.output_attr branch from b0c081b to 4671565 Compare March 7, 2024 13:44
…ribute to output arguments

Adds a new pass option `add-result-attr` that will make the pass add the attribute
`{bufferize.result}` to each argument that was converted from a result.

To be able to test this, the pass option was added to the tablegen.
To avoid collisions with the existing, manually defined option struct
`BufferResultsToOutParamsOptions`, that one was renamed to
`BufferResultsToOutParamsOpts`.
@mgehre-amd mgehre-amd force-pushed the matthias.output_attr branch from 4671565 to be1b0a1 Compare March 14, 2024 06:50
@mgehre-amd mgehre-amd merged commit e6048b7 into llvm:main Mar 14, 2024
@mgehre-amd mgehre-amd deleted the matthias.output_attr branch March 14, 2024 06:50
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:bufferization Bufferization infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants