Skip to content

Commit c1047ba

Browse files
authored
[MLIR] Enable pattern only for scf.forall op (llvm#110230)
The init args shape might change in the loop body and hence the pattern doesn't hold true.
1 parent d9cd607 commit c1047ba

File tree

1 file changed

+19
-4
lines changed

1 file changed

+19
-4
lines changed

mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "mlir/Dialect/Arith/Utils/Utils.h"
1919
#include "mlir/Dialect/MemRef/IR/MemRef.h"
2020
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
21+
#include "mlir/Dialect/SCF/IR/SCF.h"
2122
#include "mlir/Dialect/Tensor/IR/Tensor.h"
2223
#include "mlir/Interfaces/InferTypeOpInterface.h"
2324
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -131,11 +132,25 @@ struct IterArgsToInitArgs : public OpRewritePattern<tensor::DimOp> {
131132
auto blockArg = dyn_cast<BlockArgument>(dimOp.getSource());
132133
if (!blockArg)
133134
return failure();
134-
auto loopLikeOp =
135-
dyn_cast<LoopLikeOpInterface>(blockArg.getParentBlock()->getParentOp());
136-
if (!loopLikeOp)
135+
// TODO: Enable this for loopLikeInterface. Restricting for scf.for
136+
// because the init args shape might change in the loop body.
137+
// For e.g.:
138+
// ```
139+
// %0 = tensor.empty(%c1) : tensor<?xf32>
140+
// %r = scf.for %iv = %c0 to %c10 step %c1 iter_args(%arg0 = %0) ->
141+
// tensor<?xf32> {
142+
// %1 = tensor.dim %arg0, %c0 : tensor<?xf32>
143+
// %2 = arith.addi %c1, %1 : index
144+
// %3 = tensor.empty(%2) : tensor<?xf32>
145+
// scf.yield %3 : tensor<?xf32>
146+
// }
147+
//
148+
// ```
149+
auto forAllOp =
150+
dyn_cast<scf::ForallOp>(blockArg.getParentBlock()->getParentOp());
151+
if (!forAllOp)
137152
return failure();
138-
Value initArg = loopLikeOp.getTiedLoopInit(blockArg)->get();
153+
Value initArg = forAllOp.getTiedLoopInit(blockArg)->get();
139154
rewriter.modifyOpInPlace(
140155
dimOp, [&]() { dimOp.getSourceMutable().assign(initArg); });
141156
return success();

0 commit comments

Comments
 (0)