Skip to content

Commit c52cf40

Browse files
committed
add isForOpYieldResultOfInnerLoop check
1 parent f7dd9d8 commit c52cf40

File tree

1 file changed

+22
-3
lines changed

1 file changed

+22
-3
lines changed

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1251,9 +1251,9 @@ static FailureOr<OpResult> getRealProducerFromExtractSliceOp(
12511251
///
12521252
/// If `%2 = scf.for` is given without a specific predicate function, this
12531253
/// function will return three nested loops: %0 + %1 + %2.
1254-
static SmallVector<LoopLikeOpInterface> getOuterNestLoopsWhile(
1255-
LoopLikeOpInterface loop,
1256-
const std::function<LogicalResult(LoopLikeOpInterface)> &pred) {
1254+
static SmallVector<LoopLikeOpInterface>
1255+
getOuterNestLoopsWhile(LoopLikeOpInterface loop,
1256+
function_ref<LogicalResult(LoopLikeOpInterface)> pred) {
12571257
SmallVector<LoopLikeOpInterface> nestLoops = {loop};
12581258
auto outerLoop = dyn_cast<LoopLikeOpInterface>(loop->getParentOp());
12591259
while (outerLoop && succeeded(pred(outerLoop))) {
@@ -1264,6 +1264,21 @@ static SmallVector<LoopLikeOpInterface> getOuterNestLoopsWhile(
12641264
return {nestLoops.rbegin(), nestLoops.rend()};
12651265
}
12661266

1267+
/// Check if it is the ForOp that yield the result of inner loop
1268+
static LogicalResult isForOpYieldResultOfInnerLoop(LoopLikeOpInterface loop) {
1269+
if (auto forOp = dyn_cast<scf::ForOp>(loop.getOperation())) {
1270+
for (auto &&[index, op] :
1271+
llvm::enumerate(forOp.getBody()->getOperations())) {
1272+
// If the orderIndex of inner loop is the last second one before the
1273+
// yieldOp of ForOp, the given loop must yield the result of inner loop.
1274+
if (isa<LoopLikeOpInterface>(op)) {
1275+
return success((index + 2) == forOp.getBody()->getOperations().size());
1276+
}
1277+
}
1278+
}
1279+
return failure();
1280+
}
1281+
12671282
/// Enhanced version for basic implementation of fusing producer, which can deal
12681283
/// with multi-level candidates. E.g.
12691284
///
@@ -1295,6 +1310,10 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
12951310
// get nest loops between next candidate sliceOp and tiled producer.
12961311
auto whileProducerOutOfLoopBlock =
12971312
[&fuseProducerResult](LoopLikeOpInterface loop) -> LogicalResult {
1313+
// ensure that all surrounding outer loops are just yielding the result of
1314+
// the inner loops.
1315+
if (failed(isForOpYieldResultOfInnerLoop(loop)))
1316+
return failure();
12981317
if (fuseProducerResult) {
12991318
Block &body = loop->getRegion(0).front();
13001319
if (fuseProducerResult->tiledAndFusedProducer.getDefiningOp()

0 commit comments

Comments
 (0)