Skip to content

Commit ec9640c

Browse files
committed
extend consumer fuse to nested scf loop (v2)
1 parent 8aea5cc commit ec9640c

File tree

3 files changed

+310
-23
lines changed

3 files changed

+310
-23
lines changed

mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,10 @@ struct SCFFuseConsumerOfSliceResult {
250250
*tiledAndFusedConsumerOperand; // Tiled and fused consumer's operand.
251251
SmallVector<Operation *> tiledOps;
252252
};
253+
FailureOr<scf::SCFFuseConsumerOfSliceResult>
254+
tileAndFuseConsumerOfSliceImpl(RewriterBase &rewriter,
255+
Operation *candidateSliceOp);
256+
253257
FailureOr<scf::SCFFuseConsumerOfSliceResult>
254258
tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp);
255259

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 210 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1147,12 +1147,24 @@ checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
11471147
/// failure otherwise.
11481148
static FailureOr<OpOperand *> getConsumerFromUses(Value val,
11491149
Block *containingOpBlock) {
1150-
// Step 1. Check that the value has exactly one use.
1151-
if (!llvm::hasSingleElement(val.getUses()))
1150+
// Step 1. Check that the value has exactly one use except for scf.yield.
1151+
OpOperand *operand = nullptr;
1152+
for (auto &use : val.getUses()) {
1153+
Operation *user = use.getOwner();
1154+
if (isa<tensor::InsertSliceOp>(user) ||
1155+
isa<tensor::ParallelInsertSliceOp>(user))
1156+
continue;
1157+
else {
1158+
if (operand)
1159+
return failure();
1160+
else
1161+
operand = &use;
1162+
}
1163+
}
1164+
if (!operand)
11521165
return failure();
11531166
// Step 2. Get uses.
1154-
OpOperand &operand = (*val.getUses().begin());
1155-
Operation *consumerOp = operand.getOwner();
1167+
Operation *consumerOp = operand->getOwner();
11561168
// TODO: We have to init result of consumer before scf.for, use
11571169
// DestinationStyleOpInterface to get result shape from init for now.
11581170
// Add support for other op such as op has InferTypeOpInterface.
@@ -1161,7 +1173,22 @@ static FailureOr<OpOperand *> getConsumerFromUses(Value val,
11611173
return failure();
11621174
if (containingOpBlock != consumerOp->getBlock())
11631175
return failure();
1164-
return &operand;
1176+
return operand;
1177+
}
1178+
1179+
/// Return perfectly outer loops of given ForOp(included), sorted from
1180+
/// outer to inner.
1181+
static SmallVector<scf::ForOp> getPerfectlyOuterLoops(scf::ForOp loop) {
1182+
SmallVector<scf::ForOp> outerLoops = {loop};
1183+
auto forOp = loop->getParentOfType<scf::ForOp>();
1184+
while (forOp) {
1185+
Block &body = forOp.getRegion().front();
1186+
if (body.begin() != std::prev(body.end(), 2))
1187+
break;
1188+
outerLoops.push_back(forOp);
1189+
forOp = forOp->getParentOfType<scf::ForOp>();
1190+
}
1191+
return {outerLoops.rbegin(), outerLoops.rend()};
11651192
}
11661193

11671194
/// Fetch the untiled consumer of a scf.for's result which is yielded by a
@@ -1181,9 +1208,10 @@ getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) {
11811208
auto forOp = dyn_cast<scf::ForOp>(containingOp);
11821209
if (!forOp)
11831210
return failure();
1184-
Value resultingValue = forOp->getResult(resultNumber);
1211+
scf::ForOp topLevelForOp = getPerfectlyOuterLoops(forOp).front();
1212+
Value resultingValue = topLevelForOp->getResult(resultNumber);
11851213

1186-
return getConsumerFromUses(resultingValue, containingOp->getBlock());
1214+
return getConsumerFromUses(resultingValue, topLevelForOp->getBlock());
11871215
}
11881216

11891217
/// Fetch the first untiled consumer of a scf.forall's result which is yielded
@@ -1302,8 +1330,8 @@ fixTerminatorSCFInParallel(RewriterBase &rewriter, scf::ForallOp newForallOp,
13021330
/// Implementation of fusing consumer of a single slice by computing the
13031331
/// slice of the consumer in-place for scf loop.
13041332
FailureOr<scf::SCFFuseConsumerOfSliceResult>
1305-
mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
1306-
Operation *candidateSliceOp) {
1333+
mlir::scf::tileAndFuseConsumerOfSliceImpl(RewriterBase &rewriter,
1334+
Operation *candidateSliceOp) {
13071335
if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
13081336
candidateSliceOp))
13091337
return failure();
@@ -1337,22 +1365,25 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
13371365
if (isInsertSliceOp) {
13381366
auto forOp = candidateSliceOp->getParentOfType<scf::ForOp>();
13391367
oldLoopOp = forOp;
1340-
llvm::append_range(newOuts, forOp.getInits());
1341-
oldLoopBody = forOp.getBody();
13421368
initSize = forOp.getInits().size();
13431369
} else {
13441370
auto forallOp = candidateSliceOp->getParentOfType<scf::ForallOp>();
13451371
oldLoopOp = forallOp;
1346-
llvm::append_range(newOuts, forallOp.getOutputs());
1347-
oldLoopBody = forallOp.getBody();
13481372
initSize = forallOp.getOutputs().size();
13491373
rank = forallOp.getRank();
13501374
}
13511375

1352-
if (failed(checkAssumptionForLoop(oldLoopOp, consumerOp))) {
1376+
Operation *oldTopLevelLoop = oldLoopOp;
1377+
SmallVector<scf::ForOp> oldNestedForOps, newNestedForOps;
1378+
if (isInsertSliceOp) {
1379+
oldNestedForOps = getPerfectlyOuterLoops(cast<scf::ForOp>(oldLoopOp));
1380+
oldTopLevelLoop = oldNestedForOps.front();
1381+
}
1382+
if (failed(checkAssumptionForLoop(oldTopLevelLoop, consumerOp))) {
13531383
return rewriter.notifyMatchFailure(
1354-
oldLoopOp, "containing loop op should either yield just one value or "
1355-
"have the consumer op as its first user");
1384+
oldTopLevelLoop,
1385+
"containing loop op should either yield just one value or "
1386+
"have the consumer op as its first user");
13561387
}
13571388

13581389
OpBuilder::InsertionGuard g(rewriter);
@@ -1361,28 +1392,59 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
13611392
auto dstOp = cast<DestinationStyleOpInterface>(consumerOp);
13621393
SmallVector<Value> dpsInits =
13631394
llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
1364-
if (llvm::is_contained(dpsInits, oldLoopOp->getResult(resultNumber))) {
1395+
if (llvm::is_contained(dpsInits, oldTopLevelLoop->getResult(resultNumber))) {
13651396
return rewriter.notifyMatchFailure(
13661397
consumerOp,
13671398
"consumer op taking the result of scf.for as init is not supported");
13681399
}
1369-
newOuts.append(dpsInits);
1400+
SmallVector<Value> newInitAppend = dpsInits;
13701401

13711402
Location loc = oldLoopOp->getLoc();
13721403

13731404
// 3. Create new scf loop op.
13741405
rewriter.setInsertionPoint(consumerOp);
1406+
1407+
// 3.a Create new outer scf loops if necessary
1408+
bool isNestedForOps = isInsertSliceOp && oldNestedForOps.size() > 1;
1409+
if (isNestedForOps) {
1410+
for (auto &forOp : MutableArrayRef(oldNestedForOps).drop_back()) {
1411+
SmallVector<Value> newInits;
1412+
newInits = llvm::to_vector(forOp.getInits());
1413+
newInits.append(newInitAppend.begin(), newInitAppend.end());
1414+
auto newLoop = rewriter.create<scf::ForOp>(
1415+
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
1416+
forOp.getStep(), newInits);
1417+
newInitAppend = llvm::map_to_vector(
1418+
newLoop.getRegionIterArgs().take_back(newInitAppend.size()),
1419+
[](BlockArgument bArg) -> Value { return bArg; });
1420+
rewriter.mergeBlocks(
1421+
forOp.getBody(), newLoop.getBody(),
1422+
newLoop.getBody()->getArguments().take_front(initSize + 1));
1423+
rewriter.replaceOp(
1424+
forOp, newLoop->getResults().take_front(forOp->getNumResults()));
1425+
newNestedForOps.push_back(newLoop);
1426+
}
1427+
rewriter.setInsertionPoint(oldNestedForOps.back());
1428+
}
1429+
1430+
// 3.b Create new inner most scf loop
13751431
Operation *newLoopOp = nullptr;
13761432
Block *newLoopBody = nullptr;
13771433
if (isInsertSliceOp) {
13781434
auto forOp = cast<scf::ForOp>(oldLoopOp);
1435+
llvm::append_range(newOuts, forOp.getInits());
1436+
newOuts.append(newInitAppend);
1437+
oldLoopBody = forOp.getBody();
13791438
auto newForOp = rewriter.create<scf::ForOp>(loc, forOp.getLowerBound(),
13801439
forOp.getUpperBound(),
13811440
forOp.getStep(), newOuts);
13821441
newLoopOp = newForOp;
13831442
newLoopBody = newForOp.getBody();
13841443
} else {
13851444
auto forallOp = cast<scf::ForallOp>(oldLoopOp);
1445+
llvm::append_range(newOuts, forallOp.getOutputs());
1446+
newOuts.append(newInitAppend);
1447+
oldLoopBody = forallOp.getBody();
13861448
auto newForallOp = rewriter.create<scf::ForallOp>(
13871449
loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
13881450
forallOp.getMixedStep(), newOuts, forallOp.getMapping());
@@ -1496,28 +1558,155 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
14961558
newForallOp.getBody()->getArguments().drop_front(rank + initSize));
14971559
}
14981560

1499-
// 12. Replace the result of scf loop and consumer op with new loop's results.
1561+
// 12. Restore outer loops from inner to outer
1562+
if (isNestedForOps) {
1563+
newNestedForOps.push_back(cast<scf::ForOp>(newLoopOp));
1564+
for (auto [outerLoop, innerLoop] :
1565+
llvm::zip_equal(MutableArrayRef(newNestedForOps).drop_back(),
1566+
MutableArrayRef(newNestedForOps).drop_front())) {
1567+
auto forOp = cast<scf::ForOp>(outerLoop);
1568+
auto outerLoopYield =
1569+
cast<scf::YieldOp>(forOp.getBody()->getTerminator());
1570+
SmallVector<Value> newYields =
1571+
llvm::to_vector(outerLoopYield.getOperands());
1572+
ValueRange additionalYields =
1573+
innerLoop->getResults().take_back(newInitAppend.size());
1574+
newYields.append(additionalYields.begin(), additionalYields.end());
1575+
rewriter.setInsertionPoint(outerLoopYield);
1576+
rewriter.replaceOpWithNewOp<scf::YieldOp>(outerLoopYield, newYields);
1577+
}
1578+
}
1579+
1580+
// 13. Replace the result of scf loop and consumer op with new loop's results.
15001581
for (auto &&[oldResult, newResult] :
15011582
llvm::zip_first(oldLoopOp->getResults(), newLoopOp->getResults())) {
15021583
rewriter.replaceAllUsesWith(oldResult, newResult);
15031584
}
15041585

1586+
Operation *newTopLevelLoop =
1587+
isNestedForOps ? newNestedForOps.front() : newLoopOp;
15051588
for (auto &&[oldResult, newResult] :
15061589
llvm::zip(consumerOp->getResults(),
1507-
newLoopOp->getResults().drop_front(initSize))) {
1590+
newTopLevelLoop->getResults().drop_front(initSize))) {
15081591
rewriter.replaceAllUsesWith(oldResult, newResult);
15091592
}
15101593

1511-
// 13. Need to erase the old scf loop and the cloned consumer op.
1594+
// 14. Need to erase the old scf loop and the cloned consumer op.
15121595
rewriter.eraseOp(oldLoopOp);
15131596
rewriter.eraseOp(clonedConsumerOp);
15141597

1598+
// 15. Need to erase the cloned insertSliceOp and unused extractSliceOp in
1599+
// avoid of complex domination analysis
1600+
assert(clonedInsertSliceOp->hasOneUse());
1601+
auto unUsedExtractOp =
1602+
cast<tensor::ExtractSliceOp>((*clonedInsertSliceOp->getUsers().begin()));
1603+
rewriter.eraseOp(unUsedExtractOp);
1604+
rewriter.eraseOp(clonedInsertSliceOp);
1605+
15151606
return scf::SCFFuseConsumerOfSliceResult{
15161607
consumerOpOperand,
15171608
&(tileAndFuseResult->tiledOps[0]->getOpOperand(operandNumber)),
15181609
tileAndFuseResult->tiledOps};
15191610
}
15201611

1612+
/// Get the result of top-level loop which yields the target InsertSliceOp. E.g
1613+
/// ```
1614+
/// %1 = scf.for
1615+
/// %2 = scf.for
1616+
/// %3 = scf.for
1617+
/// ...
1618+
/// %4 = insert
1619+
/// yield %4
1620+
/// %5 = insert %3
1621+
/// yield %5
1622+
/// yield %2
1623+
/// ```
1624+
/// @param targetSliceOp: %4 = insert
1625+
/// @param insertSliceOpChain: chain of all related insert sliceOp
1626+
/// @return resultValue: %1
1627+
static FailureOr<Value> getResultOfTopLevelLoopYieldInsertSliceOp(
1628+
Operation *targetSliceOp,
1629+
SmallVectorImpl<OffsetSizeAndStrideOpInterface> &insertSliceOpChain,
1630+
int curDepth = 0, int maxDepth = 5) {
1631+
assert(isa<OffsetSizeAndStrideOpInterface>(targetSliceOp));
1632+
// Control recursive time in avoid of stack overflow
1633+
if (curDepth > maxDepth)
1634+
return failure();
1635+
1636+
insertSliceOpChain.push_back(
1637+
cast<OffsetSizeAndStrideOpInterface>(targetSliceOp));
1638+
Value resultOfLoop;
1639+
if (auto sliceOp = dyn_cast<tensor::ParallelInsertSliceOp>(targetSliceOp)) {
1640+
Value destValue = sliceOp.getDest();
1641+
auto iterArg = cast<BlockArgument>(destValue);
1642+
auto forallOp = dyn_cast<scf::ForallOp>(iterArg.getOwner()->getParentOp());
1643+
if (!forallOp)
1644+
return failure();
1645+
resultOfLoop = forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg));
1646+
} else if (auto sliceOp = dyn_cast<tensor::InsertSliceOp>(targetSliceOp)) {
1647+
Value resultValue = sliceOp.getResult();
1648+
for (auto &useOperand : resultValue.getUses()) {
1649+
if (auto yieldOp = dyn_cast<scf::YieldOp>(useOperand.getOwner())) {
1650+
if (llvm::detail::isPresent(resultOfLoop))
1651+
return failure();
1652+
auto forOp = dyn_cast<LoopLikeOpInterface>(yieldOp->getParentOp());
1653+
if (!forOp)
1654+
return failure();
1655+
resultOfLoop = forOp->getResult(useOperand.getOperandNumber());
1656+
}
1657+
}
1658+
}
1659+
1660+
if (!llvm::detail::isPresent(resultOfLoop))
1661+
return failure();
1662+
1663+
while (true) {
1664+
bool walkThroughOuterLoop = false;
1665+
for (OpOperand &useOperand : resultOfLoop.getUses()) {
1666+
if (auto sliceOp =
1667+
dyn_cast<OffsetSizeAndStrideOpInterface>(useOperand.getOwner())) {
1668+
return getResultOfTopLevelLoopYieldInsertSliceOp(
1669+
sliceOp, insertSliceOpChain, curDepth + 1);
1670+
} else if (auto yieldOp = dyn_cast<scf::YieldOp>(useOperand.getOwner())) {
1671+
// walk through outer loop
1672+
auto forOp = dyn_cast<LoopLikeOpInterface>(yieldOp->getParentOp());
1673+
if (!forOp)
1674+
return failure();
1675+
resultOfLoop = forOp->getResult(useOperand.getOperandNumber());
1676+
walkThroughOuterLoop = true;
1677+
break;
1678+
}
1679+
}
1680+
if (!walkThroughOuterLoop)
1681+
break;
1682+
}
1683+
return resultOfLoop;
1684+
}
1685+
1686+
/// Fusing real consumer of a single slice even within complex nested loops via
1687+
/// multiple application of `tileAndFuseConsumerOfSliceImpl`.
1688+
FailureOr<scf::SCFFuseConsumerOfSliceResult>
1689+
mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
1690+
Operation *candidateSliceOp) {
1691+
SmallVector<OffsetSizeAndStrideOpInterface> sliceOpChain;
1692+
if (failed(getResultOfTopLevelLoopYieldInsertSliceOp(candidateSliceOp,
1693+
sliceOpChain)))
1694+
return failure();
1695+
1696+
FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResult;
1697+
// reverse from outer to inner
1698+
std::reverse(sliceOpChain.begin(), sliceOpChain.end());
1699+
// multiple application of `tileAndFuseConsumerOfSliceImpl`
1700+
for (auto &sliceOp : sliceOpChain) {
1701+
fuseConsumerResult = tileAndFuseConsumerOfSliceImpl(rewriter, sliceOp);
1702+
if (failed(fuseConsumerResult)) {
1703+
return rewriter.notifyMatchFailure(sliceOp,
1704+
"could not fuse consumer of sliceOp");
1705+
}
1706+
}
1707+
return fuseConsumerResult;
1708+
}
1709+
15211710
//===----------------------------------------------------------------------===//
15221711
// lowerToLoopsUsingSCFForOp implementation.
15231712
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)