Skip to content

Commit 0884a18

Browse files
committed
split extension for single loop and single uses respectively
1 parent 12083de commit 0884a18

File tree

2 files changed

+276
-30
lines changed

2 files changed

+276
-30
lines changed

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 272 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1464,12 +1464,38 @@ checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
14641464
/// failure otherwise.
14651465
static FailureOr<OpOperand *> getConsumerFromUses(Value val,
14661466
Block *containingOpBlock) {
1467-
// Step 1. Check that the value has exactly one use.
1468-
if (!llvm::hasSingleElement(val.getUses()))
1469-
return failure();
1467+
// Step 1. Check that the value has exactly one use excluding `insertSliceOp`
1468+
// or `ParallelInsertSliceOp`.
1469+
OpOperand *operand = nullptr;
1470+
for (auto &use : val.getUses()) {
1471+
Operation *user = use.getOwner();
1472+
if (isa<tensor::ParallelInsertSliceOp>(user))
1473+
continue;
1474+
if (isa<tensor::InsertSliceOp>(user)) {
1475+
// The only one use is expected as dummy extractSliceOp without any uses.
1476+
// For more details, please refer to:
1477+
// https://github.com/llvm/llvm-project/pull/88712#discussion_r1609384470
1478+
if (user->hasOneUse()) {
1479+
if (auto extractOp =
1480+
dyn_cast<tensor::ExtractSliceOp>(*user->getUsers().begin());
1481+
extractOp && extractOp->use_empty()) {
1482+
// Erase dummy extractSliceOp.
1483+
extractOp.erase();
1484+
// DO NOT erase `user` inside iteration of `getUses`.
1485+
user->moveBefore(&containingOpBlock->getOperations().back());
1486+
continue;
1487+
}
1488+
}
1489+
// Otherwise return.
1490+
return failure();
1491+
}
1492+
// Only one valid use expected
1493+
if (operand)
1494+
return failure();
1495+
operand = &use;
1496+
}
14701497
// Step 2. Get uses.
1471-
OpOperand &operand = (*val.getUses().begin());
1472-
Operation *consumerOp = operand.getOwner();
1498+
Operation *consumerOp = operand->getOwner();
14731499
// TODO: We have to init result of consumer before scf.for, use
14741500
// DestinationStyleOpInterface to get result shape from init for now.
14751501
// Add support for other op such as op has InferTypeOpInterface.
@@ -1478,7 +1504,54 @@ static FailureOr<OpOperand *> getConsumerFromUses(Value val,
14781504
return failure();
14791505
if (containingOpBlock != consumerOp->getBlock())
14801506
return failure();
1481-
return &operand;
1507+
return operand;
1508+
}
1509+
1510+
/// Recursively find the outer nest loops of given loop(included) while the
1511+
/// predict function succeed, sorted from outer to inner.
1512+
///
1513+
/// @param loop: target loop, note that this loop will be also included. I.e.
1514+
/// if no other nest loops were found, just return itself.
1515+
/// @param pred: predict function, the termination condition of recursive
1516+
/// process.
1517+
/// @return Outer Nest Loops: nest loops outside given target loop(included).
1518+
///
1519+
/// E.g.
1520+
///
1521+
/// ```
1522+
/// %0 = scf.for()
1523+
/// %1 = scf.for()
1524+
/// %2 = scf.for()
1525+
/// ```
1526+
///
1527+
/// If `%2 = scf.for` is given without specific prediction function, this
1528+
/// function will return three nest loops: %0 + %1 + %2.
1529+
static SmallVector<LoopLikeOpInterface> getOuterNestLoopsWhile(
1530+
LoopLikeOpInterface loop,
1531+
const std::function<LogicalResult(LoopLikeOpInterface)> &pred) {
1532+
SmallVector<LoopLikeOpInterface> nestLoops = {loop};
1533+
auto outerLoop = dyn_cast<LoopLikeOpInterface>(loop->getParentOp());
1534+
while (outerLoop && succeeded(pred(outerLoop))) {
1535+
nestLoops.push_back(outerLoop);
1536+
outerLoop = dyn_cast<LoopLikeOpInterface>(outerLoop->getParentOp());
1537+
}
1538+
// sorted from outer to inner
1539+
return {nestLoops.rbegin(), nestLoops.rend()};
1540+
}
1541+
1542+
/// Check if it is the ForOp that yield the result of inner loop
1543+
static LogicalResult isForOpYieldResultOfInnerLoop(LoopLikeOpInterface loop) {
1544+
if (auto forOp = dyn_cast<scf::ForOp>(loop.getOperation())) {
1545+
Block::OpListType &opsInLoopBody = forOp.getBody()->getOperations();
1546+
for (auto &&[index, op] : llvm::enumerate(opsInLoopBody)) {
1547+
// If the orderIndex of inner loop is the last second one before the
1548+
// yieldOp of ForOp, the given loop must yield the result of inner loop.
1549+
if (isa<LoopLikeOpInterface>(op)) {
1550+
return success((index + 2) == opsInLoopBody.size());
1551+
}
1552+
}
1553+
}
1554+
return failure();
14821555
}
14831556

14841557
/// Fetch the untiled consumer of a scf.for's result which is yielded by a
@@ -1498,9 +1571,11 @@ getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) {
14981571
auto forOp = dyn_cast<scf::ForOp>(containingOp);
14991572
if (!forOp)
15001573
return failure();
1501-
Value resultingValue = forOp->getResult(resultNumber);
1574+
LoopLikeOpInterface topLevelForOp =
1575+
getOuterNestLoopsWhile(forOp, isForOpYieldResultOfInnerLoop).front();
1576+
Value resultingValue = topLevelForOp->getResult(resultNumber);
15021577

1503-
return getConsumerFromUses(resultingValue, containingOp->getBlock());
1578+
return getConsumerFromUses(resultingValue, topLevelForOp->getBlock());
15041579
}
15051580

15061581
/// Fetch the first untiled consumer of a scf.forall's result which is yielded
@@ -1618,9 +1693,9 @@ fixTerminatorSCFInParallel(RewriterBase &rewriter, scf::ForallOp newForallOp,
16181693

16191694
/// Implementation of fusing consumer of a single slice by computing the
16201695
/// slice of the consumer in-place for scf loop.
1621-
FailureOr<scf::SCFFuseConsumerOfSliceResult>
1622-
mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
1623-
Operation *candidateSliceOp) {
1696+
static FailureOr<scf::SCFFuseConsumerOfSliceResult>
1697+
tileAndFuseConsumerOfSliceImpl(RewriterBase &rewriter,
1698+
Operation *candidateSliceOp) {
16241699
if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
16251700
candidateSliceOp))
16261701
return failure();
@@ -1654,52 +1729,99 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
16541729
if (isInsertSliceOp) {
16551730
auto forOp = candidateSliceOp->getParentOfType<scf::ForOp>();
16561731
oldLoopOp = forOp;
1657-
llvm::append_range(newOuts, forOp.getInits());
1658-
oldLoopBody = forOp.getBody();
16591732
initSize = forOp.getInits().size();
16601733
} else {
16611734
auto forallOp = candidateSliceOp->getParentOfType<scf::ForallOp>();
16621735
oldLoopOp = forallOp;
1663-
llvm::append_range(newOuts, forallOp.getOutputs());
1664-
oldLoopBody = forallOp.getBody();
16651736
initSize = forallOp.getOutputs().size();
16661737
rank = forallOp.getRank();
16671738
}
16681739

1669-
if (failed(checkAssumptionForLoop(oldLoopOp, consumerOp))) {
1740+
// There are two possible cases regarding `oldLoopOp` here:
1741+
// 1. single `scf.forall` or `scf.for`.
1742+
// 2. inner-most `scf.for` insider nest `scf.loop` structure, where the
1743+
// top-level loop is the outer-most one of these nested loops.
1744+
Operation *oldTopLevelLoop = oldLoopOp;
1745+
SmallVector<LoopLikeOpInterface> oldNestedForOps, newNestedForOps;
1746+
if (isInsertSliceOp) {
1747+
oldNestedForOps =
1748+
getOuterNestLoopsWhile(cast<LoopLikeOpInterface>(oldTopLevelLoop),
1749+
isForOpYieldResultOfInnerLoop);
1750+
oldTopLevelLoop = oldNestedForOps.front();
1751+
}
1752+
1753+
if (failed(checkAssumptionForLoop(oldTopLevelLoop, consumerOp))) {
16701754
return rewriter.notifyMatchFailure(
1671-
oldLoopOp, "containing loop op should either yield just one value or "
1672-
"have the consumer op as its first user");
1755+
oldTopLevelLoop,
1756+
"containing loop op should either yield just one value or "
1757+
"have the consumer op as its first user");
16731758
}
16741759

16751760
OpBuilder::InsertionGuard g(rewriter);
16761761

16771762
// 2. Check consumer is not using scf loop's output as init.
1678-
auto dstOp = cast<DestinationStyleOpInterface>(consumerOp);
1763+
auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
1764+
if (!dstOp)
1765+
return rewriter.notifyMatchFailure(consumerOp,
1766+
"consumer op is not DPS operation");
16791767
SmallVector<Value> dpsInits =
16801768
llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
1681-
if (llvm::is_contained(dpsInits, oldLoopOp->getResult(resultNumber))) {
1769+
if (llvm::is_contained(dpsInits, oldTopLevelLoop->getResult(resultNumber))) {
16821770
return rewriter.notifyMatchFailure(
16831771
consumerOp,
16841772
"consumer op taking the result of scf.for as init is not supported");
16851773
}
1686-
newOuts.append(dpsInits);
1774+
SmallVector<Value> newInitAppend = dpsInits;
16871775

16881776
Location loc = oldLoopOp->getLoc();
16891777

16901778
// 3. Create new scf loop op.
16911779
rewriter.setInsertionPoint(consumerOp);
1780+
1781+
// 3.a Create new outer scf loops with new Inits only if nested `scf.for`
1782+
// case was found.
1783+
bool isNestedForOps = isInsertSliceOp && oldNestedForOps.size() > 1;
1784+
if (isNestedForOps) {
1785+
for (auto &&[index, loopOp] :
1786+
llvm::enumerate(MutableArrayRef(oldNestedForOps).drop_back())) {
1787+
auto forOp = cast<scf::ForOp>(loopOp);
1788+
SmallVector<Value> newInits;
1789+
newInits = llvm::to_vector(forOp.getInits());
1790+
newInits.append(newInitAppend.begin(), newInitAppend.end());
1791+
auto newLoop = rewriter.create<scf::ForOp>(
1792+
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
1793+
forOp.getStep(), newInits);
1794+
newInitAppend = llvm::map_to_vector(
1795+
newLoop.getRegionIterArgs().take_back(newInitAppend.size()),
1796+
[](BlockArgument bArg) -> Value { return bArg; });
1797+
rewriter.mergeBlocks(
1798+
forOp.getBody(), newLoop.getBody(),
1799+
newLoop.getBody()->getArguments().take_front(initSize + 1));
1800+
rewriter.replaceOp(
1801+
forOp, newLoop->getResults().take_front(forOp->getNumResults()));
1802+
newNestedForOps.push_back(newLoop);
1803+
rewriter.setInsertionPointAfter(oldNestedForOps[index + 1]);
1804+
}
1805+
}
1806+
1807+
// 3.b Create new inner-most scf loop
16921808
Operation *newLoopOp = nullptr;
16931809
Block *newLoopBody = nullptr;
16941810
if (isInsertSliceOp) {
16951811
auto forOp = cast<scf::ForOp>(oldLoopOp);
1812+
oldLoopBody = forOp.getBody();
1813+
llvm::append_range(newOuts, forOp.getInits());
1814+
newOuts.append(newInitAppend);
16961815
auto newForOp = rewriter.create<scf::ForOp>(loc, forOp.getLowerBound(),
16971816
forOp.getUpperBound(),
16981817
forOp.getStep(), newOuts);
16991818
newLoopOp = newForOp;
17001819
newLoopBody = newForOp.getBody();
17011820
} else {
17021821
auto forallOp = cast<scf::ForallOp>(oldLoopOp);
1822+
oldLoopBody = forallOp.getBody();
1823+
llvm::append_range(newOuts, forallOp.getOutputs());
1824+
newOuts.append(newInitAppend);
17031825
auto newForallOp = rewriter.create<scf::ForallOp>(
17041826
loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
17051827
forallOp.getMixedStep(), newOuts, forallOp.getMapping());
@@ -1813,19 +1935,41 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
18131935
newForallOp.getBody()->getArguments().drop_front(rank + initSize));
18141936
}
18151937

1816-
// 12. Replace the result of scf loop and consumer op with new loop's results.
1938+
// 12. Restore outer loops from inner to outer only if nested `scf.for`
1939+
// case was found.
1940+
if (isNestedForOps) {
1941+
newNestedForOps.push_back(cast<scf::ForOp>(newLoopOp));
1942+
for (auto [outerLoop, innerLoop] :
1943+
llvm::zip_equal(MutableArrayRef(newNestedForOps).drop_back(),
1944+
MutableArrayRef(newNestedForOps).drop_front())) {
1945+
auto forOp = cast<scf::ForOp>(outerLoop);
1946+
auto outerLoopYield =
1947+
cast<scf::YieldOp>(forOp.getBody()->getTerminator());
1948+
SmallVector<Value> newYields =
1949+
llvm::to_vector(outerLoopYield.getOperands());
1950+
ValueRange additionalYields =
1951+
innerLoop->getResults().take_back(newInitAppend.size());
1952+
newYields.append(additionalYields.begin(), additionalYields.end());
1953+
rewriter.setInsertionPoint(outerLoopYield);
1954+
rewriter.replaceOpWithNewOp<scf::YieldOp>(outerLoopYield, newYields);
1955+
}
1956+
}
1957+
1958+
// 13. Replace the result of scf loop and consumer op with new loop's results.
18171959
for (auto &&[oldResult, newResult] :
18181960
llvm::zip_first(oldLoopOp->getResults(), newLoopOp->getResults())) {
18191961
rewriter.replaceAllUsesWith(oldResult, newResult);
18201962
}
18211963

1964+
Operation *newTopLevelLoop =
1965+
isNestedForOps ? newNestedForOps.front() : newLoopOp;
18221966
for (auto &&[oldResult, newResult] :
18231967
llvm::zip(consumerOp->getResults(),
1824-
newLoopOp->getResults().drop_front(initSize))) {
1968+
newTopLevelLoop->getResults().drop_front(initSize))) {
18251969
rewriter.replaceAllUsesWith(oldResult, newResult);
18261970
}
18271971

1828-
// 13. Need to erase the old scf loop and the cloned consumer op.
1972+
// 14. Need to erase the old scf loop and the cloned consumer op.
18291973
rewriter.eraseOp(oldLoopOp);
18301974
rewriter.eraseOp(clonedConsumerOp);
18311975

@@ -1835,6 +1979,110 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
18351979
tileAndFuseResult->tiledOps};
18361980
}
18371981

1982+
/// Get the real consumers from candidate InsertSliceOp. E.g
1983+
///
1984+
/// ```
1985+
/// %1 = scf.for
1986+
/// %2 = scf.for
1987+
/// %3 = scf.for
1988+
/// ...
1989+
/// %4 = insert
1990+
/// yield %4
1991+
/// %5 = insert %3
1992+
/// yield %5
1993+
/// yield %2
1994+
/// %6 = consumerOp ins(%1)
1995+
/// ```
1996+
///
1997+
/// @param candidateSliceOp: %4 = insert
1998+
/// @param forwardSlice: in-out parameter populated by forward insertSliceOps
1999+
/// @return OpOperand consumers: %6 = consumerOp ins(%1)
2000+
static FailureOr<SmallVector<OpOperand *>> getRealConsumersFromInsertSliceOp(
2001+
Operation *candidateSliceOp,
2002+
SmallVector<OffsetSizeAndStrideOpInterface> &forwardSlice,
2003+
unsigned curDepth = 0, unsigned maxDepth = 5) {
2004+
if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
2005+
candidateSliceOp))
2006+
return failure();
2007+
// Control recursive time in avoid of stack overflow
2008+
if (curDepth > maxDepth)
2009+
return failure();
2010+
2011+
forwardSlice.push_back(
2012+
cast<OffsetSizeAndStrideOpInterface>(candidateSliceOp));
2013+
Value resultOfLoop;
2014+
if (auto sliceOp =
2015+
dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
2016+
Value destValue = sliceOp.getDest();
2017+
auto iterArg = cast<BlockArgument>(destValue);
2018+
auto forallOp = dyn_cast<scf::ForallOp>(iterArg.getOwner()->getParentOp());
2019+
if (!forallOp)
2020+
return failure();
2021+
resultOfLoop = forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg));
2022+
} else if (auto sliceOp = dyn_cast<tensor::InsertSliceOp>(candidateSliceOp)) {
2023+
Value resultValue = sliceOp.getResult();
2024+
for (auto &useOperand : resultValue.getUses()) {
2025+
if (auto yieldOp = dyn_cast<scf::YieldOp>(useOperand.getOwner())) {
2026+
if (llvm::detail::isPresent(resultOfLoop))
2027+
return failure();
2028+
auto forOp = dyn_cast<LoopLikeOpInterface>(yieldOp->getParentOp());
2029+
if (!forOp)
2030+
return failure();
2031+
resultOfLoop = forOp->getResult(useOperand.getOperandNumber());
2032+
}
2033+
}
2034+
}
2035+
2036+
if (!llvm::detail::isPresent(resultOfLoop))
2037+
return failure();
2038+
2039+
bool traverseUpperLoop;
2040+
do {
2041+
traverseUpperLoop = false;
2042+
for (OpOperand &useOperand : resultOfLoop.getUses()) {
2043+
if (auto sliceOp =
2044+
dyn_cast<OffsetSizeAndStrideOpInterface>(useOperand.getOwner())) {
2045+
return getRealConsumersFromInsertSliceOp(sliceOp, forwardSlice,
2046+
curDepth + 1);
2047+
}
2048+
if (auto yieldOp = dyn_cast<scf::YieldOp>(useOperand.getOwner())) {
2049+
// Walk through outer loop.
2050+
auto forOp = dyn_cast<LoopLikeOpInterface>(yieldOp->getParentOp());
2051+
if (!forOp)
2052+
return failure();
2053+
resultOfLoop = forOp->getResult(useOperand.getOperandNumber());
2054+
traverseUpperLoop = true;
2055+
break;
2056+
}
2057+
}
2058+
} while (traverseUpperLoop);
2059+
// Return all operands using result of top level loop.
2060+
return llvm::map_to_vector(resultOfLoop.getUses(),
2061+
[](OpOperand &u) -> OpOperand * { return &u; });
2062+
}
2063+
2064+
/// Fusing real consumer of a single slice even within complex nested loops via
2065+
/// multiple application of `tileAndFuseConsumerOfSliceImpl`.
2066+
FailureOr<scf::SCFFuseConsumerOfSliceResult>
2067+
mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
2068+
Operation *candidateSliceOp) {
2069+
SmallVector<OffsetSizeAndStrideOpInterface> forwardSlice;
2070+
if (failed(getRealConsumersFromInsertSliceOp(candidateSliceOp, forwardSlice)))
2071+
return failure();
2072+
2073+
FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResult;
2074+
// Reverse forward slice from outer to inner.
2075+
std::reverse(forwardSlice.begin(), forwardSlice.end());
2076+
// Multiple application of `tileAndFuseConsumerOfSliceImpl`.
2077+
for (auto &sliceOp : forwardSlice) {
2078+
fuseConsumerResult = tileAndFuseConsumerOfSliceImpl(rewriter, sliceOp);
2079+
if (failed(fuseConsumerResult))
2080+
return rewriter.notifyMatchFailure(sliceOp,
2081+
"could not fuse consumer of sliceOp");
2082+
}
2083+
return fuseConsumerResult;
2084+
}
2085+
18382086
//===----------------------------------------------------------------------===//
18392087
// lowerToLoopsUsingSCFForOp implementation.
18402088
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)