-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[MLIR] Enable pattern only for scf.forall op #110230
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir Author: Prashant Kumar (pashu123) ChangesThe init args shape might change in the loop body and hence the pattern doesn't hold true. Full diff: https://github.com/llvm/llvm-project/pull/110230.diff 1 Files Affected:
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
index fb2921fec9f79d..aea26602dfb7a4 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
@@ -131,11 +132,24 @@ struct IterArgsToInitArgs : public OpRewritePattern<tensor::DimOp> {
auto blockArg = dyn_cast<BlockArgument>(dimOp.getSource());
if (!blockArg)
return failure();
- auto loopLikeOp =
- dyn_cast<LoopLikeOpInterface>(blockArg.getParentBlock()->getParentOp());
- if (!loopLikeOp)
+ // TODO: Enable this for loopLikeInterface. Restricting for scf.for
+ // because the init args shape might change in the loop body.
+ // For e.g.:
+ // ```
+ // %0 = tensor.empty(%c1) : tensor<?xf32>
+ // %r = scf.for %iv = %c0 to %c10 step %c1 iter_args(%arg0 = %0) -> tensor<?xf32> {
+ // %1 = tensor.dim %arg0, %c0 : tensor<?xf32>
+ // %2 = arith.addi %c1, %1 : index
+ // %3 = tensor.empty(%2) : tensor<?xf32>
+ // scf.yield %3 : tensor<?xf32>
+ // }
+ //
+ // ```
+ auto forAllOp =
+ dyn_cast<scf::ForallOp>(blockArg.getParentBlock()->getParentOp());
+ if (!forAllOp)
return failure();
- Value initArg = loopLikeOp.getTiedLoopInit(blockArg)->get();
+ Value initArg = forAllOp.getTiedLoopInit(blockArg)->get();
rewriter.modifyOpInPlace(
dimOp, [&]() { dimOp.getSourceMutable().assign(initArg); });
return success();
|
@llvm/pr-subscribers-mlir-memref Author: Prashant Kumar (pashu123) ChangesThe init args shape might change in the loop body and hence the pattern doesn't hold true. Full diff: https://github.com/llvm/llvm-project/pull/110230.diff 1 Files Affected:
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
index fb2921fec9f79d..aea26602dfb7a4 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
@@ -131,11 +132,24 @@ struct IterArgsToInitArgs : public OpRewritePattern<tensor::DimOp> {
auto blockArg = dyn_cast<BlockArgument>(dimOp.getSource());
if (!blockArg)
return failure();
- auto loopLikeOp =
- dyn_cast<LoopLikeOpInterface>(blockArg.getParentBlock()->getParentOp());
- if (!loopLikeOp)
+ // TODO: Enable this for loopLikeInterface. Restricting for scf.for
+ // because the init args shape might change in the loop body.
+ // For e.g.:
+ // ```
+ // %0 = tensor.empty(%c1) : tensor<?xf32>
+ // %r = scf.for %iv = %c0 to %c10 step %c1 iter_args(%arg0 = %0) -> tensor<?xf32> {
+ // %1 = tensor.dim %arg0, %c0 : tensor<?xf32>
+ // %2 = arith.addi %c1, %1 : index
+ // %3 = tensor.empty(%2) : tensor<?xf32>
+ // scf.yield %3 : tensor<?xf32>
+ // }
+ //
+ // ```
+ auto forAllOp =
+ dyn_cast<scf::ForallOp>(blockArg.getParentBlock()->getParentOp());
+ if (!forAllOp)
return failure();
- Value initArg = loopLikeOp.getTiedLoopInit(blockArg)->get();
+ Value initArg = forAllOp.getTiedLoopInit(blockArg)->get();
rewriter.modifyOpInPlace(
dimOp, [&]() { dimOp.getSourceMutable().assign(initArg); });
return success();
|
The init args shape might change in the loop body and hence the pattern doesn't hold true.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Scoping it to forall ops looks okay to me.
auto loopLikeOp = | ||
dyn_cast<LoopLikeOpInterface>(blockArg.getParentBlock()->getParentOp()); | ||
if (!loopLikeOp) | ||
// TODO: Enable this for loopLikeInterface. Restricting for scf.for |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fyi, for scf.for
we already have this pattern here: https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp#L87
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it would make sense to move this pattern to the same file.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think that pattern is wrong for the reason why this is being restricted to forall only. The init arg shape can vary from iteration to iteration. There it is added as a canonicalization which is even worse.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This pattern has an additional isShapePreserving
check that makes it safe. Btw, that check is quite conservative, it could be improved by checking for DestinationStyleOpInterface
instead of hard-coding a few ops in the analysis.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@MaheshRavishankar @matthias-springer I'll be moving the pattern from LoopCanonicalization to resolveShapedTypeResultDims and updating that to process scf.forall. That sounds reasonable to me.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll be merging this and refactoring in upcoming PRs.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wait this needs a test as well. Please add a test as a follow up.
The test was added in another commit: 0cf8447 I think this is a follow-up for comments from @matthias-springer |
This PR is changing some behavior, so the test from 0cf8447 is unlikely to be testing the specific behavior change here. What you're likely missing is a negative test showing that the pattern does not apply to other ops than scf.forall |
The init args shape might change in the loop body and hence the pattern doesn't hold true.