@@ -1251,9 +1251,9 @@ static FailureOr<OpResult> getRealProducerFromExtractSliceOp(
1251
1251
// /
1252
1252
// / If `%2 = scf.for` is given without specific prediction function, this
1253
1253
// / function will return three nest loops: %0 + %1 + %2.
1254
- static SmallVector<LoopLikeOpInterface> getOuterNestLoopsWhile (
1255
- LoopLikeOpInterface loop,
1256
- const std::function <LogicalResult(LoopLikeOpInterface)> & pred) {
1254
+ static SmallVector<LoopLikeOpInterface>
1255
+ getOuterNestLoopsWhile ( LoopLikeOpInterface loop,
1256
+ function_ref <LogicalResult(LoopLikeOpInterface)> pred) {
1257
1257
SmallVector<LoopLikeOpInterface> nestLoops = {loop};
1258
1258
auto outerLoop = dyn_cast<LoopLikeOpInterface>(loop->getParentOp ());
1259
1259
while (outerLoop && succeeded (pred (outerLoop))) {
@@ -1264,6 +1264,21 @@ static SmallVector<LoopLikeOpInterface> getOuterNestLoopsWhile(
1264
1264
return {nestLoops.rbegin (), nestLoops.rend ()};
1265
1265
}
1266
1266
1267
+ // / Check whether the given loop is a ForOp that yields the result of its inner loop.
1268
+ static LogicalResult isForOpYieldResultOfInnerLoop (LoopLikeOpInterface loop) {
1269
+ if (auto forOp = dyn_cast<scf::ForOp>(loop.getOperation ())) {
1270
+ for (auto &&[index, op] :
1271
+ llvm::enumerate (forOp.getBody ()->getOperations ())) {
1272
+ // If the inner loop is the second-to-last operation of the body, i.e. the one
1273
+ // right before the yieldOp of ForOp, the given loop must yield the result of the inner loop.
1274
+ if (isa<LoopLikeOpInterface>(op)) {
1275
+ return success ((index + 2 ) == forOp.getBody ()->getOperations ().size ());
1276
+ }
1277
+ }
1278
+ }
1279
+ return failure ();
1280
+ }
1281
+
1267
1282
// / Enhanced version for basic implementation of fusing producer, which can deal
1268
1283
// / with multi-level candidates. E.g.
1269
1284
// /
@@ -1295,6 +1310,10 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
1295
1310
// get nest loops between next candidate sliceOp and tiled producer.
1296
1311
auto whileProducerOutOfLoopBlock =
1297
1312
[&fuseProducerResult](LoopLikeOpInterface loop) -> LogicalResult {
1313
+ // Ensure that every surrounding outer loop just yields the result of
1314
+ // its inner loop.
1315
+ if (failed (isForOpYieldResultOfInnerLoop (loop)))
1316
+ return failure ();
1298
1317
if (fuseProducerResult) {
1299
1318
Block &body = loop->getRegion (0 ).front ();
1300
1319
if (fuseProducerResult->tiledAndFusedProducer .getDefiningOp ()
0 commit comments