|
18 | 18 | #include "mlir/Dialect/Arith/Utils/Utils.h"
|
19 | 19 | #include "mlir/Dialect/MemRef/IR/MemRef.h"
|
20 | 20 | #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
|
| 21 | +#include "mlir/Dialect/SCF/IR/SCF.h" |
21 | 22 | #include "mlir/Dialect/Tensor/IR/Tensor.h"
|
22 | 23 | #include "mlir/Interfaces/InferTypeOpInterface.h"
|
23 | 24 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
@@ -131,11 +132,25 @@ struct IterArgsToInitArgs : public OpRewritePattern<tensor::DimOp> {
|
131 | 132 | auto blockArg = dyn_cast<BlockArgument>(dimOp.getSource());
|
132 | 133 | if (!blockArg)
|
133 | 134 | return failure();
|
134 |
| - auto loopLikeOp = |
135 |
| - dyn_cast<LoopLikeOpInterface>(blockArg.getParentBlock()->getParentOp()); |
136 |
| - if (!loopLikeOp) |
| 135 | + // TODO: Enable this for loopLikeInterface. Restricting for scf.for |
| 136 | + // because the init args shape might change in the loop body. |
| 137 | + // For e.g.: |
| 138 | + // ``` |
| 139 | + // %0 = tensor.empty(%c1) : tensor<?xf32> |
| 140 | + // %r = scf.for %iv = %c0 to %c10 step %c1 iter_args(%arg0 = %0) -> |
| 141 | + // tensor<?xf32> { |
| 142 | + // %1 = tensor.dim %arg0, %c0 : tensor<?xf32> |
| 143 | + // %2 = arith.addi %c1, %1 : index |
| 144 | + // %3 = tensor.empty(%2) : tensor<?xf32> |
| 145 | + // scf.yield %3 : tensor<?xf32> |
| 146 | + // } |
| 147 | + // |
| 148 | + // ``` |
| 149 | + auto forAllOp = |
| 150 | + dyn_cast<scf::ForallOp>(blockArg.getParentBlock()->getParentOp()); |
| 151 | + if (!forAllOp) |
137 | 152 | return failure();
|
138 |
| - Value initArg = loopLikeOp.getTiedLoopInit(blockArg)->get(); |
| 153 | + Value initArg = forAllOp.getTiedLoopInit(blockArg)->get(); |
139 | 154 | rewriter.modifyOpInPlace(
|
140 | 155 | dimOp, [&]() { dimOp.getSourceMutable().assign(initArg); });
|
141 | 156 | return success();
|
|
0 commit comments