Skip to content

[mlir][memref] Fix hoist-static-allocs option of buffer-results-to-out-params when function parameters are returned #102093

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Sep 4, 2024

Conversation

Menooker
Copy link
Contributor

@Menooker Menooker commented Aug 6, 2024

buffer-results-to-out-params pass will have a referencing nullptr error when hoist-static-allocs option is on, when the return value of a function is a parameter of the function. This PR fixes this issue.

@llvmbot llvmbot added mlir mlir:bufferization Bufferization infrastructure labels Aug 6, 2024
@llvmbot
Copy link
Member

llvmbot commented Aug 6, 2024

@llvm/pr-subscribers-mlir-bufferization

@llvm/pr-subscribers-mlir

Author: Menooker (Menooker)

Changes

buffer-results-to-out-params pass will have a nullptr error when hoist-static-allocs option is on, when the return value of a function is a parameter of the function. This PR fixes this issue and let the pass remove the return value in the ReturnOp when

  • the value type is memref
  • and the value is a function parameter

Full diff: https://github.com/llvm/llvm-project/pull/102093.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td (+5)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp (+15-2)
  • (modified) mlir/test/Transforms/buffer-results-to-out-params-elim.mlir (+12-1)
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
index 1cece818dbbbc..d6f13b2153828 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
@@ -321,6 +321,11 @@ def BufferResultsToOutParams : Pass<"buffer-results-to-out-params", "ModuleOp">
     This optimization applies on the returned memref which has static shape and
     is allocated by memref.alloc in the function. It will use the memref given
     in function argument to replace the allocated memref.
+
+    If the hoist-static-allocs option is on, and a function returns a memref
+    from the function argument, the pass will avoid the memory-copy from
+    the input function argument to the "out param", and leave the "out param"
+    unused. 
   }];
   let options = [
     Option<"addResultAttribute", "add-result-attr", "bool",
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
index b19636adaa69e..16a42ca779381 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
@@ -102,6 +102,14 @@ updateFuncOp(func::FuncOp func,
   return success();
 }
 
+static bool isFunctionArgument(mlir::Value value) {
+  // Check if the value is a Function argument
+  if (auto blockArg = dyn_cast<mlir::BlockArgument>(value)) {
+    return blockArg.getOwner()->isEntryBlock();
+  }
+  return false;
+}
+
 // Updates all ReturnOps in the scope of the given func::FuncOp by either
 // keeping them as return values or copying the associated buffer contents into
 // the given out-params.
@@ -120,10 +128,15 @@ static LogicalResult updateReturnOps(func::FuncOp func,
     }
     OpBuilder builder(op);
     for (auto [orig, arg] : llvm::zip(copyIntoOutParams, appendedEntryArgs)) {
-      if (hoistStaticAllocs && isa<memref::AllocOp>(orig.getDefiningOp()) &&
-          mlir::cast<MemRefType>(orig.getType()).hasStaticShape()) {
+      bool mayHoistStaticAlloc =
+          hoistStaticAllocs &&
+          mlir::cast<MemRefType>(orig.getType()).hasStaticShape();
+      if (mayHoistStaticAlloc &&
+          isa_and_nonnull<memref::AllocOp>(orig.getDefiningOp())) {
         orig.replaceAllUsesWith(arg);
         orig.getDefiningOp()->erase();
+      } else if (mayHoistStaticAlloc && isFunctionArgument(orig)) {
+        // do nothing but remove the value from the return op.
       } else {
         if (failed(memCpyFn(builder, op.getLoc(), orig, arg)))
           return WalkResult::interrupt();
diff --git a/mlir/test/Transforms/buffer-results-to-out-params-elim.mlir b/mlir/test/Transforms/buffer-results-to-out-params-elim.mlir
index f77dbfaa6cb11..2bd9a9a045531 100644
--- a/mlir/test/Transforms/buffer-results-to-out-params-elim.mlir
+++ b/mlir/test/Transforms/buffer-results-to-out-params-elim.mlir
@@ -34,4 +34,15 @@ func.func @basic_dynamic(%d: index) -> (memref<?xf32>) {
   %b = memref.alloc(%d) : memref<?xf32>
   "test.source"(%b)  : (memref<?xf32>) -> ()
   return %b : memref<?xf32>
-}
\ No newline at end of file
+}
+
+// CHECK-LABEL:   func @return_arg(
+// CHECK-SAME:        %[[ARG0:.*]]: memref<128x256xf32>, %[[ARG1:.*]]: memref<128x256xf32>, %[[ARG2:.*]]: memref<128x256xf32>) {
+// CHECK:           "test.source"(%[[ARG0]], %[[ARG1]])
+// CHECK-NOT:       memref.copy
+// CHECK:           return
+// CHECK:         }
+func.func @return_arg(%arg0: memref<128x256xf32>, %arg1: memref<128x256xf32>) -> memref<128x256xf32> {
+  "test.source"(%arg0, %arg1)  : (memref<128x256xf32>, memref<128x256xf32>) -> ()
+  return %arg0 : memref<128x256xf32>
+}

Copy link
Contributor

@cxy-1993 cxy-1993 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the patch. Looks good to me in general, except some logical errors.

orig.replaceAllUsesWith(arg);
orig.getDefiningOp()->erase();
} else if (mayHoistStaticAlloc && isFunctionArgument(orig)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

e don't need to set the hoistStaticAllocs option to true to do this optimization: the return value doesn't involve any allocations that weren't already present in the function parameters.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It indeed changes the behavior of the program. It is not sure to me what is the expected behavior of this pass without hoistStaticAllocs, if we return the memref from an arg. If we should copy the data from arg to arg?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We dont have to. My point is we dont need hoistStaticAllocs to eliminate copy.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually I am not sure if it is ok to avoid the memcpy for this case...

Consider the simple function:

func add1(%0: memref<...>) -> memref<...> {
    my.add(%0)
    return %0
}

This pass will change the signature to:

func add1(%0: memref<...>, %out: memref<...>)  {
 ....
}

If the user of this pass is unaware of the in-place-return behavior, they may pass different buffers for %0 and %out`. The safest way is to copy %0 to %out. It will always meet the expectation of the users.

If we remove the memcpy from %0 to %out, actually %out is an unused parameter, and we should expect the user aware of this in-place-return behavior. If we turn it on by default, will it be a surprise to them?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right, this will confuse users. We should avoid doing this.

@@ -321,6 +321,11 @@ def BufferResultsToOutParams : Pass<"buffer-results-to-out-params", "ModuleOp">
This optimization applies on the returned memref which has static shape and
is allocated by memref.alloc in the function. It will use the memref given
in function argument to replace the allocated memref.

If the hoist-static-allocs option is on, and a function returns a memref
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

@Menooker
Copy link
Contributor Author

Hi @cxy-1993 I have updated this PR for the entry-block issue. Would you please have another look? Thanks!

@Menooker
Copy link
Contributor Author

Menooker commented Sep 3, 2024

Hi @cxy-1993 , thanks for your review and comments. I have removed the confusing optimization in this patch. And I am focusing on the bug fix now. Would you please help to review again? Thanks!

Copy link
Contributor

@cxy-1993 cxy-1993 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Please fix test style before merge.

@Menooker Menooker merged commit 26645ae into llvm:main Sep 4, 2024
8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:bufferization Bufferization infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants