@@ -1251,9 +1251,9 @@ static FailureOr<OpResult> getRealProducerFromExtractSliceOp(
1251
1251
// /
1252
1252
// / If `%2 = scf.for` is given without specific prediction function, this
1253
1253
// / function will return three nest loops: %0 + %1 + %2.
1254
- static SmallVector<LoopLikeOpInterface> getOuterNestLoopsWhile (
1255
- LoopLikeOpInterface loop,
1256
- const std::function <LogicalResult(LoopLikeOpInterface)> & pred) {
1254
+ static SmallVector<LoopLikeOpInterface>
1255
+ getOuterNestLoopsWhile ( LoopLikeOpInterface loop,
1256
+ function_ref <LogicalResult(LoopLikeOpInterface)> pred) {
1257
1257
SmallVector<LoopLikeOpInterface> nestLoops = {loop};
1258
1258
auto outerLoop = dyn_cast<LoopLikeOpInterface>(loop->getParentOp ());
1259
1259
while (outerLoop && succeeded (pred (outerLoop))) {
@@ -1264,6 +1264,21 @@ static SmallVector<LoopLikeOpInterface> getOuterNestLoopsWhile(
1264
1264
return {nestLoops.rbegin (), nestLoops.rend ()};
1265
1265
}
1266
1266
1267
+ // / Check if it is the ForOp that yield the result of inner loop
1268
+ static LogicalResult isForOpYieldResultOfInnerLoop (LoopLikeOpInterface loop) {
1269
+ if (auto forOp = dyn_cast<scf::ForOp>(loop.getOperation ())) {
1270
+ Block::OpListType &opsInLoopBody = forOp.getBody ()->getOperations ();
1271
+ for (auto &&[index, op] : llvm::enumerate (opsInLoopBody)) {
1272
+ // If the orderIndex of inner loop is the last second one before the
1273
+ // yieldOp of ForOp, the given loop must yield the result of inner loop.
1274
+ if (isa<LoopLikeOpInterface>(op)) {
1275
+ return success ((index + 2 ) == opsInLoopBody.size ());
1276
+ }
1277
+ }
1278
+ }
1279
+ return failure ();
1280
+ }
1281
+
1267
1282
// / Enhanced version for basic implementation of fusing producer, which can deal
1268
1283
// / with multi-level candidates. E.g.
1269
1284
// /
@@ -1282,10 +1297,10 @@ std::optional<scf::SCFFuseProducerOfSliceResult>
1282
1297
mlir::scf::tileAndFuseProducerOfSlice (RewriterBase &rewriter,
1283
1298
Operation *candidateSliceOp) {
1284
1299
SmallVector<tensor::ExtractSliceOp> backwardSlice;
1285
- if (failed (
1286
- getRealProducerFromExtractSliceOp (candidateSliceOp, backwardSlice))) {
1300
+ FailureOr<OpResult> realProducer =
1301
+ getRealProducerFromExtractSliceOp (candidateSliceOp, backwardSlice);
1302
+ if (failed (realProducer))
1287
1303
return std::nullopt;
1288
- }
1289
1304
1290
1305
std::optional<scf::SCFFuseProducerOfSliceResult> fuseProducerResult;
1291
1306
// reverse from outer to inner
@@ -1294,14 +1309,18 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
1294
1309
for (auto &&[index, sliceOp] : llvm::enumerate (backwardSlice)) {
1295
1310
// get nest loops between next candidate sliceOp and tiled producer.
1296
1311
auto whileProducerOutOfLoopBlock =
1297
- [&fuseProducerResult](LoopLikeOpInterface loop) -> LogicalResult {
1298
- if (fuseProducerResult) {
1299
- Block &body = loop->getRegion (0 ).front ();
1300
- if (fuseProducerResult->tiledAndFusedProducer .getDefiningOp ()
1301
- ->getBlock () == &body)
1302
- return failure ();
1303
- }
1304
- return success ();
1312
+ [&fuseProducerResult,
1313
+ &realProducer](LoopLikeOpInterface loop) -> LogicalResult {
1314
+ // ensure that all surrounding outer loops are just yielding the result of
1315
+ // the inner loops.
1316
+ if (failed (isForOpYieldResultOfInnerLoop (loop)))
1317
+ return failure ();
1318
+ Operation *originalOp =
1319
+ fuseProducerResult
1320
+ ? fuseProducerResult->tiledAndFusedProducer .getDefiningOp ()
1321
+ : realProducer->getDefiningOp ();
1322
+ Block &body = loop->getRegion (0 ).front ();
1323
+ return success (originalOp->getBlock () != &body);
1305
1324
};
1306
1325
SmallVector<LoopLikeOpInterface> outerLoops =
1307
1326
getOuterNestLoopsWhile (sliceOp->getParentOfType <LoopLikeOpInterface>(),
0 commit comments