[MLIR] Enable pattern only for scf.forall op #110230

Merged: 1 commit, Oct 17, 2024

Conversation

pashu123 (Member)

The shape of the init args might change in the loop body, and hence the pattern does not hold in general.

llvmbot (Member) commented Sep 27, 2024

@llvm/pr-subscribers-mlir

Author: Prashant Kumar (pashu123)

Changes

The shape of the init args might change in the loop body, and hence the pattern does not hold in general.


Full diff: https://github.com/llvm/llvm-project/pull/110230.diff

1 file affected:

  • (modified) mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp (+18-4)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
index fb2921fec9f79d..aea26602dfb7a4 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
@@ -15,6 +15,7 @@
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
@@ -131,11 +132,24 @@ struct IterArgsToInitArgs : public OpRewritePattern<tensor::DimOp> {
     auto blockArg = dyn_cast<BlockArgument>(dimOp.getSource());
     if (!blockArg)
       return failure();
-    auto loopLikeOp =
-        dyn_cast<LoopLikeOpInterface>(blockArg.getParentBlock()->getParentOp());
-    if (!loopLikeOp)
+    // TODO: Enable this for loopLikeInterface. Restricting for scf.for
+    // because the init args shape might change in the loop body.
+    // For e.g.:
+    // ```
+    //  %0 = tensor.empty(%c1) : tensor<?xf32>
+    //  %r = scf.for %iv = %c0 to %c10 step %c1 iter_args(%arg0 = %0) -> tensor<?xf32> {
+    //    %1 = tensor.dim %arg0, %c0 : tensor<?xf32>
+    //    %2 = arith.addi %c1, %1 : index
+    //    %3 = tensor.empty(%2) : tensor<?xf32>
+    //    scf.yield %3 : tensor<?xf32>
+    //  }
+    //
+    // ```
+    auto forAllOp =
+        dyn_cast<scf::ForallOp>(blockArg.getParentBlock()->getParentOp());
+    if (!forAllOp)
       return failure();
-    Value initArg = loopLikeOp.getTiedLoopInit(blockArg)->get();
+    Value initArg = forAllOp.getTiedLoopInit(blockArg)->get();
     rewriter.modifyOpInPlace(
         dimOp, [&]() { dimOp.getSourceMutable().assign(initArg); });
     return success();
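
For contrast with the scf.for example in the TODO comment above, here is a minimal sketch (illustrative IR, not taken from this PR; all names are made up) of the scf.forall case the pattern still handles: the shared_outs block argument is only updated in place through tensor.parallel_insert_slice, so its shape always matches the init operand and the tensor.dim on it can safely be resolved against that operand.

```mlir
func.func @forall_dim_of_shared_out(%d: index, %n: index) -> index {
  %c0 = arith.constant 0 : index
  %empty = tensor.empty(%d) : tensor<?xf32>
  %res = scf.forall (%i) in (%n) shared_outs(%out = %empty) -> (tensor<?xf32>) {
    // %out always has the shape of %empty, so this dim can be rewritten to
    // "tensor.dim %empty, %c0" by the pattern.
    %dim = tensor.dim %out, %c0 : tensor<?xf32>
    %slice = tensor.extract_slice %out[%i] [1] [1] : tensor<?xf32> to tensor<1xf32>
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %slice into %out[%i] [1] [1]
          : tensor<1xf32> into tensor<?xf32>
    }
  }
  %total = tensor.dim %res, %c0 : tensor<?xf32>
  return %total : index
}
```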

hanhanW (Contributor) left a comment

Scoping it to forall ops looks okay to me.

    auto loopLikeOp =
        dyn_cast<LoopLikeOpInterface>(blockArg.getParentBlock()->getParentOp());
    if (!loopLikeOp)
    // TODO: Enable this for loopLikeInterface. Restricting for scf.for
Member:

I think it would make sense to move this pattern to the same file.

Contributor:

I think that pattern is wrong, for the same reason this one is being restricted to forall only: the init arg shape can vary from iteration to iteration. There it is added as a canonicalization, which is even worse.

Member:

This pattern has an additional isShapePreserving check that makes it safe. By the way, that check is quite conservative; it could be improved by checking for DestinationStyleOpInterface instead of hard-coding a few ops in the analysis.
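
To illustrate the shape-preserving case such a check accepts, here is a hypothetical sketch (not code from either pattern; names are made up): when the yielded value is produced by a destination-style op whose destination is the iter_arg itself, the iter_arg keeps the shape of the init operand, so folding tensor.dim of the block argument to tensor.dim of the init would be safe even for scf.for.

```mlir
func.func @shape_preserving_for(%init: tensor<?xf32>, %chunk: tensor<1xf32>) -> tensor<?xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c10 = arith.constant 10 : index
  %r = scf.for %iv = %c0 to %c10 step %c1 iter_args(%arg0 = %init) -> (tensor<?xf32>) {
    // tensor.insert_slice implements DestinationStyleOpInterface: its result
    // takes the shape of its destination (%arg0), so the iter_arg shape never
    // changes and "tensor.dim %arg0" equals "tensor.dim %init".
    %updated = tensor.insert_slice %chunk into %arg0[%iv] [1] [1]
        : tensor<1xf32> into tensor<?xf32>
    scf.yield %updated : tensor<?xf32>
  }
  return %r : tensor<?xf32>
}
```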

Member (Author):

@MaheshRavishankar @matthias-springer I'll be moving the pattern from LoopCanonicalization to resolveShapedTypeResultDims and updating that to process scf.forall. That sounds reasonable to me.

Member (Author):

I'll be merging this and refactoring in upcoming PRs.

pashu123 merged commit c1047ba into llvm:main on Oct 17, 2024 (8 checks passed).
MaheshRavishankar (Contributor) left a comment

Wait, this needs a test as well. Please add a test as a follow-up.

hanhanW (Contributor) commented Oct 17, 2024

> Wait, this needs a test as well. Please add a test as a follow-up.

The test was added in another commit: 0cf8447. I think this is a follow-up to the comments from @matthias-springer.

joker-eph (Collaborator)

This PR is changing some behavior, so the test from 0cf8447 is unlikely to be testing the specific behavior change here.

What you're likely missing is a negative test showing that the pattern does not apply to ops other than scf.forall.
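
A possible shape for such a test, as a sketch only: it assumes the pattern is exercised through the --resolve-ranked-shaped-type-result-dims flag used by the existing resolve-dim-ops.mlir tests (adjust the RUN line if the pattern is registered elsewhere), and all function and value names are made up.

```mlir
// RUN: mlir-opt --resolve-ranked-shaped-type-result-dims %s | FileCheck %s

// Negative test: the dim of the scf.for iter_arg grows every iteration, so it
// must not be rewritten to a dim of the init operand.
// CHECK-LABEL: func @no_fold_dim_of_scf_for_iter_arg
//       CHECK:   scf.for {{.*}} iter_args(%[[ARG:.+]] = %{{.*}})
//       CHECK:     tensor.dim %[[ARG]]
func.func @no_fold_dim_of_scf_for_iter_arg(%size: index) -> tensor<?xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c10 = arith.constant 10 : index
  %init = tensor.empty(%size) : tensor<?xf32>
  %r = scf.for %iv = %c0 to %c10 step %c1 iter_args(%arg0 = %init) -> (tensor<?xf32>) {
    %dim = tensor.dim %arg0, %c0 : tensor<?xf32>
    %grown = arith.addi %dim, %c1 : index
    %next = tensor.empty(%grown) : tensor<?xf32>
    scf.yield %next : tensor<?xf32>
  }
  return %r : tensor<?xf32>
}
```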
