Skip to content

Commit 23796bf

Browse files
committed
add isForOpYieldResultOfInnerLoop check
1 parent f7dd9d8 commit 23796bf

File tree

1 file changed

+33
-14
lines changed

1 file changed

+33
-14
lines changed

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1251,9 +1251,9 @@ static FailureOr<OpResult> getRealProducerFromExtractSliceOp(
12511251
///
12521252
/// If `%2 = scf.for` is given without specific prediction function, this
12531253
/// function will return three nest loops: %0 + %1 + %2.
1254-
static SmallVector<LoopLikeOpInterface> getOuterNestLoopsWhile(
1255-
LoopLikeOpInterface loop,
1256-
const std::function<LogicalResult(LoopLikeOpInterface)> &pred) {
1254+
static SmallVector<LoopLikeOpInterface>
1255+
getOuterNestLoopsWhile(LoopLikeOpInterface loop,
1256+
function_ref<LogicalResult(LoopLikeOpInterface)> pred) {
12571257
SmallVector<LoopLikeOpInterface> nestLoops = {loop};
12581258
auto outerLoop = dyn_cast<LoopLikeOpInterface>(loop->getParentOp());
12591259
while (outerLoop && succeeded(pred(outerLoop))) {
@@ -1264,6 +1264,21 @@ static SmallVector<LoopLikeOpInterface> getOuterNestLoopsWhile(
12641264
return {nestLoops.rbegin(), nestLoops.rend()};
12651265
}
12661266

1267+
/// Check if it is the ForOp that yield the result of inner loop
1268+
static LogicalResult isForOpYieldResultOfInnerLoop(LoopLikeOpInterface loop) {
1269+
if (auto forOp = dyn_cast<scf::ForOp>(loop.getOperation())) {
1270+
Block::OpListType &opsInLoopBody = forOp.getBody()->getOperations();
1271+
for (auto &&[index, op] : llvm::enumerate(opsInLoopBody)) {
1272+
// If the orderIndex of inner loop is the last second one before the
1273+
// yieldOp of ForOp, the given loop must yield the result of inner loop.
1274+
if (isa<LoopLikeOpInterface>(op)) {
1275+
return success((index + 2) == opsInLoopBody.size());
1276+
}
1277+
}
1278+
}
1279+
return failure();
1280+
}
1281+
12671282
/// Enhanced version for basic implementation of fusing producer, which can deal
12681283
/// with multi-level candidates. E.g.
12691284
///
@@ -1282,10 +1297,10 @@ std::optional<scf::SCFFuseProducerOfSliceResult>
12821297
mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
12831298
Operation *candidateSliceOp) {
12841299
SmallVector<tensor::ExtractSliceOp> backwardSlice;
1285-
if (failed(
1286-
getRealProducerFromExtractSliceOp(candidateSliceOp, backwardSlice))) {
1300+
FailureOr<OpResult> realProducer =
1301+
getRealProducerFromExtractSliceOp(candidateSliceOp, backwardSlice);
1302+
if (failed(realProducer))
12871303
return std::nullopt;
1288-
}
12891304

12901305
std::optional<scf::SCFFuseProducerOfSliceResult> fuseProducerResult;
12911306
// reverse from outer to inner
@@ -1294,14 +1309,18 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
12941309
for (auto &&[index, sliceOp] : llvm::enumerate(backwardSlice)) {
12951310
// get nest loops between next candidate sliceOp and tiled producer.
12961311
auto whileProducerOutOfLoopBlock =
1297-
[&fuseProducerResult](LoopLikeOpInterface loop) -> LogicalResult {
1298-
if (fuseProducerResult) {
1299-
Block &body = loop->getRegion(0).front();
1300-
if (fuseProducerResult->tiledAndFusedProducer.getDefiningOp()
1301-
->getBlock() == &body)
1302-
return failure();
1303-
}
1304-
return success();
1312+
[&fuseProducerResult,
1313+
&realProducer](LoopLikeOpInterface loop) -> LogicalResult {
1314+
// ensure that all surrounding outer loops are just yielding the result of
1315+
// the inner loops.
1316+
if (failed(isForOpYieldResultOfInnerLoop(loop)))
1317+
return failure();
1318+
Operation *originalOp =
1319+
fuseProducerResult
1320+
? fuseProducerResult->tiledAndFusedProducer.getDefiningOp()
1321+
: realProducer->getDefiningOp();
1322+
Block &body = loop->getRegion(0).front();
1323+
return success(originalOp->getBlock() != &body);
13051324
};
13061325
SmallVector<LoopLikeOpInterface> outerLoops =
13071326
getOuterNestLoopsWhile(sliceOp->getParentOfType<LoopLikeOpInterface>(),

0 commit comments

Comments
 (0)