@@ -1464,12 +1464,38 @@ checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
1464
1464
// / failure otherwise.
1465
1465
static FailureOr<OpOperand *> getConsumerFromUses (Value val,
1466
1466
Block *containingOpBlock) {
1467
- // Step 1. Check that the value has exactly one use.
1468
- if (!llvm::hasSingleElement (val.getUses ()))
1469
- return failure ();
1467
+ // Step 1. Check that the value has exactly one use excluding `insertSliceOp`
1468
+ // or `ParallelInsertSliceOp`.
1469
+ OpOperand *operand = nullptr ;
1470
+ for (auto &use : val.getUses ()) {
1471
+ Operation *user = use.getOwner ();
1472
+ if (isa<tensor::ParallelInsertSliceOp>(user))
1473
+ continue ;
1474
+ if (isa<tensor::InsertSliceOp>(user)) {
1475
+ // The only one use is expected as dummy extractSliceOp without any uses.
1476
+ // For more details, please refer to:
1477
+ // https://github.com/llvm/llvm-project/pull/88712#discussion_r1609384470
1478
+ if (user->hasOneUse ()) {
1479
+ if (auto extractOp =
1480
+ dyn_cast<tensor::ExtractSliceOp>(*user->getUsers ().begin ());
1481
+ extractOp && extractOp->use_empty ()) {
1482
+ // Erase dummy extractSliceOp.
1483
+ extractOp.erase ();
1484
+ // DO NOT erase `user` inside iteration of `getUses`.
1485
+ user->moveBefore (&containingOpBlock->getOperations ().back ());
1486
+ continue ;
1487
+ }
1488
+ }
1489
+ // Otherwise return.
1490
+ return failure ();
1491
+ }
1492
+ // Only one valid use expected
1493
+ if (operand)
1494
+ return failure ();
1495
+ operand = &use;
1496
+ }
1470
1497
// Step 2. Get uses.
1471
- OpOperand &operand = (*val.getUses ().begin ());
1472
- Operation *consumerOp = operand.getOwner ();
1498
+ Operation *consumerOp = operand->getOwner ();
1473
1499
// TODO: We have to init result of consumer before scf.for, use
1474
1500
// DestinationStyleOpInterface to get result shape from init for now.
1475
1501
// Add support for other op such as op has InferTypeOpInterface.
@@ -1478,7 +1504,54 @@ static FailureOr<OpOperand *> getConsumerFromUses(Value val,
1478
1504
return failure ();
1479
1505
if (containingOpBlock != consumerOp->getBlock ())
1480
1506
return failure ();
1481
- return &operand;
1507
+ return operand;
1508
+ }
1509
+
1510
+ // / Recursively find the outer nest loops of given loop(included) while the
1511
+ // / predict function succeed, sorted from outer to inner.
1512
+ // /
1513
+ // / @param loop: target loop, note that this loop will be also included. I.e.
1514
+ // / if no other nest loops were found, just return itself.
1515
+ // / @param pred: predict function, the termination condition of recursive
1516
+ // / process.
1517
+ // / @return Outer Nest Loops: nest loops outside given target loop(included).
1518
+ // /
1519
+ // / E.g.
1520
+ // /
1521
+ // / ```
1522
+ // / %0 = scf.for()
1523
+ // / %1 = scf.for()
1524
+ // / %2 = scf.for()
1525
+ // / ```
1526
+ // /
1527
+ // / If `%2 = scf.for` is given without specific prediction function, this
1528
+ // / function will return three nest loops: %0 + %1 + %2.
1529
+ static SmallVector<LoopLikeOpInterface> getOuterNestLoopsWhile (
1530
+ LoopLikeOpInterface loop,
1531
+ const std::function<LogicalResult(LoopLikeOpInterface)> &pred) {
1532
+ SmallVector<LoopLikeOpInterface> nestLoops = {loop};
1533
+ auto outerLoop = dyn_cast<LoopLikeOpInterface>(loop->getParentOp ());
1534
+ while (outerLoop && succeeded (pred (outerLoop))) {
1535
+ nestLoops.push_back (outerLoop);
1536
+ outerLoop = dyn_cast<LoopLikeOpInterface>(outerLoop->getParentOp ());
1537
+ }
1538
+ // sorted from outer to inner
1539
+ return {nestLoops.rbegin (), nestLoops.rend ()};
1540
+ }
1541
+
1542
+ // / Check if it is the ForOp that yield the result of inner loop
1543
+ static LogicalResult isForOpYieldResultOfInnerLoop (LoopLikeOpInterface loop) {
1544
+ if (auto forOp = dyn_cast<scf::ForOp>(loop.getOperation ())) {
1545
+ Block::OpListType &opsInLoopBody = forOp.getBody ()->getOperations ();
1546
+ for (auto &&[index, op] : llvm::enumerate (opsInLoopBody)) {
1547
+ // If the orderIndex of inner loop is the last second one before the
1548
+ // yieldOp of ForOp, the given loop must yield the result of inner loop.
1549
+ if (isa<LoopLikeOpInterface>(op)) {
1550
+ return success ((index + 2 ) == opsInLoopBody.size ());
1551
+ }
1552
+ }
1553
+ }
1554
+ return failure ();
1482
1555
}
1483
1556
1484
1557
// / Fetch the untiled consumer of a scf.for's result which is yielded by a
@@ -1498,9 +1571,11 @@ getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) {
1498
1571
auto forOp = dyn_cast<scf::ForOp>(containingOp);
1499
1572
if (!forOp)
1500
1573
return failure ();
1501
- Value resultingValue = forOp->getResult (resultNumber);
1574
+ LoopLikeOpInterface topLevelForOp =
1575
+ getOuterNestLoopsWhile (forOp, isForOpYieldResultOfInnerLoop).front ();
1576
+ Value resultingValue = topLevelForOp->getResult (resultNumber);
1502
1577
1503
- return getConsumerFromUses (resultingValue, containingOp ->getBlock ());
1578
+ return getConsumerFromUses (resultingValue, topLevelForOp ->getBlock ());
1504
1579
}
1505
1580
1506
1581
// / Fetch the first untiled consumer of a scf.forall's result which is yielded
@@ -1618,9 +1693,9 @@ fixTerminatorSCFInParallel(RewriterBase &rewriter, scf::ForallOp newForallOp,
1618
1693
1619
1694
// / Implementation of fusing consumer of a single slice by computing the
1620
1695
// / slice of the consumer in-place for scf loop.
1621
- FailureOr<scf::SCFFuseConsumerOfSliceResult>
1622
- mlir::scf::tileAndFuseConsumerOfSlice (RewriterBase &rewriter,
1623
- Operation *candidateSliceOp) {
1696
+ static FailureOr<scf::SCFFuseConsumerOfSliceResult>
1697
+ tileAndFuseConsumerOfSliceImpl (RewriterBase &rewriter,
1698
+ Operation *candidateSliceOp) {
1624
1699
if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
1625
1700
candidateSliceOp))
1626
1701
return failure ();
@@ -1654,52 +1729,99 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
1654
1729
if (isInsertSliceOp) {
1655
1730
auto forOp = candidateSliceOp->getParentOfType <scf::ForOp>();
1656
1731
oldLoopOp = forOp;
1657
- llvm::append_range (newOuts, forOp.getInits ());
1658
- oldLoopBody = forOp.getBody ();
1659
1732
initSize = forOp.getInits ().size ();
1660
1733
} else {
1661
1734
auto forallOp = candidateSliceOp->getParentOfType <scf::ForallOp>();
1662
1735
oldLoopOp = forallOp;
1663
- llvm::append_range (newOuts, forallOp.getOutputs ());
1664
- oldLoopBody = forallOp.getBody ();
1665
1736
initSize = forallOp.getOutputs ().size ();
1666
1737
rank = forallOp.getRank ();
1667
1738
}
1668
1739
1669
- if (failed (checkAssumptionForLoop (oldLoopOp, consumerOp))) {
1740
+ // There are two possible cases regarding `oldLoopOp` here:
1741
+ // 1. single `scf.forall` or `scf.for`.
1742
+ // 2. inner-most `scf.for` insider nest `scf.loop` structure, where the
1743
+ // top-level loop is the outer-most one of these nested loops.
1744
+ Operation *oldTopLevelLoop = oldLoopOp;
1745
+ SmallVector<LoopLikeOpInterface> oldNestedForOps, newNestedForOps;
1746
+ if (isInsertSliceOp) {
1747
+ oldNestedForOps =
1748
+ getOuterNestLoopsWhile (cast<LoopLikeOpInterface>(oldTopLevelLoop),
1749
+ isForOpYieldResultOfInnerLoop);
1750
+ oldTopLevelLoop = oldNestedForOps.front ();
1751
+ }
1752
+
1753
+ if (failed (checkAssumptionForLoop (oldTopLevelLoop, consumerOp))) {
1670
1754
return rewriter.notifyMatchFailure (
1671
- oldLoopOp, " containing loop op should either yield just one value or "
1672
- " have the consumer op as its first user" );
1755
+ oldTopLevelLoop,
1756
+ " containing loop op should either yield just one value or "
1757
+ " have the consumer op as its first user" );
1673
1758
}
1674
1759
1675
1760
OpBuilder::InsertionGuard g (rewriter);
1676
1761
1677
1762
// 2. Check consumer is not using scf loop's output as init.
1678
- auto dstOp = cast<DestinationStyleOpInterface>(consumerOp);
1763
+ auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
1764
+ if (!dstOp)
1765
+ return rewriter.notifyMatchFailure (consumerOp,
1766
+ " consumer op is not DPS operation" );
1679
1767
SmallVector<Value> dpsInits =
1680
1768
llvm::map_to_vector (dstOp.getDpsInits (), [](Value v) { return v; });
1681
- if (llvm::is_contained (dpsInits, oldLoopOp ->getResult (resultNumber))) {
1769
+ if (llvm::is_contained (dpsInits, oldTopLevelLoop ->getResult (resultNumber))) {
1682
1770
return rewriter.notifyMatchFailure (
1683
1771
consumerOp,
1684
1772
" consumer op taking the result of scf.for as init is not supported" );
1685
1773
}
1686
- newOuts. append ( dpsInits) ;
1774
+ SmallVector<Value> newInitAppend = dpsInits;
1687
1775
1688
1776
Location loc = oldLoopOp->getLoc ();
1689
1777
1690
1778
// 3. Create new scf loop op.
1691
1779
rewriter.setInsertionPoint (consumerOp);
1780
+
1781
+ // 3.a Create new outer scf loops with new Inits only if nested `scf.for`
1782
+ // case was found.
1783
+ bool isNestedForOps = isInsertSliceOp && oldNestedForOps.size () > 1 ;
1784
+ if (isNestedForOps) {
1785
+ for (auto &&[index, loopOp] :
1786
+ llvm::enumerate (MutableArrayRef (oldNestedForOps).drop_back ())) {
1787
+ auto forOp = cast<scf::ForOp>(loopOp);
1788
+ SmallVector<Value> newInits;
1789
+ newInits = llvm::to_vector (forOp.getInits ());
1790
+ newInits.append (newInitAppend.begin (), newInitAppend.end ());
1791
+ auto newLoop = rewriter.create <scf::ForOp>(
1792
+ forOp.getLoc (), forOp.getLowerBound (), forOp.getUpperBound (),
1793
+ forOp.getStep (), newInits);
1794
+ newInitAppend = llvm::map_to_vector (
1795
+ newLoop.getRegionIterArgs ().take_back (newInitAppend.size ()),
1796
+ [](BlockArgument bArg) -> Value { return bArg; });
1797
+ rewriter.mergeBlocks (
1798
+ forOp.getBody (), newLoop.getBody (),
1799
+ newLoop.getBody ()->getArguments ().take_front (initSize + 1 ));
1800
+ rewriter.replaceOp (
1801
+ forOp, newLoop->getResults ().take_front (forOp->getNumResults ()));
1802
+ newNestedForOps.push_back (newLoop);
1803
+ rewriter.setInsertionPointAfter (oldNestedForOps[index + 1 ]);
1804
+ }
1805
+ }
1806
+
1807
+ // 3.b Create new inner-most scf loop
1692
1808
Operation *newLoopOp = nullptr ;
1693
1809
Block *newLoopBody = nullptr ;
1694
1810
if (isInsertSliceOp) {
1695
1811
auto forOp = cast<scf::ForOp>(oldLoopOp);
1812
+ oldLoopBody = forOp.getBody ();
1813
+ llvm::append_range (newOuts, forOp.getInits ());
1814
+ newOuts.append (newInitAppend);
1696
1815
auto newForOp = rewriter.create <scf::ForOp>(loc, forOp.getLowerBound (),
1697
1816
forOp.getUpperBound (),
1698
1817
forOp.getStep (), newOuts);
1699
1818
newLoopOp = newForOp;
1700
1819
newLoopBody = newForOp.getBody ();
1701
1820
} else {
1702
1821
auto forallOp = cast<scf::ForallOp>(oldLoopOp);
1822
+ oldLoopBody = forallOp.getBody ();
1823
+ llvm::append_range (newOuts, forallOp.getOutputs ());
1824
+ newOuts.append (newInitAppend);
1703
1825
auto newForallOp = rewriter.create <scf::ForallOp>(
1704
1826
loc, forallOp.getMixedLowerBound (), forallOp.getMixedUpperBound (),
1705
1827
forallOp.getMixedStep (), newOuts, forallOp.getMapping ());
@@ -1813,19 +1935,41 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
1813
1935
newForallOp.getBody ()->getArguments ().drop_front (rank + initSize));
1814
1936
}
1815
1937
1816
- // 12. Replace the result of scf loop and consumer op with new loop's results.
1938
+ // 12. Restore outer loops from inner to outer only if nested `scf.for`
1939
+ // case was found.
1940
+ if (isNestedForOps) {
1941
+ newNestedForOps.push_back (cast<scf::ForOp>(newLoopOp));
1942
+ for (auto [outerLoop, innerLoop] :
1943
+ llvm::zip_equal (MutableArrayRef (newNestedForOps).drop_back (),
1944
+ MutableArrayRef (newNestedForOps).drop_front ())) {
1945
+ auto forOp = cast<scf::ForOp>(outerLoop);
1946
+ auto outerLoopYield =
1947
+ cast<scf::YieldOp>(forOp.getBody ()->getTerminator ());
1948
+ SmallVector<Value> newYields =
1949
+ llvm::to_vector (outerLoopYield.getOperands ());
1950
+ ValueRange additionalYields =
1951
+ innerLoop->getResults ().take_back (newInitAppend.size ());
1952
+ newYields.append (additionalYields.begin (), additionalYields.end ());
1953
+ rewriter.setInsertionPoint (outerLoopYield);
1954
+ rewriter.replaceOpWithNewOp <scf::YieldOp>(outerLoopYield, newYields);
1955
+ }
1956
+ }
1957
+
1958
+ // 13. Replace the result of scf loop and consumer op with new loop's results.
1817
1959
for (auto &&[oldResult, newResult] :
1818
1960
llvm::zip_first (oldLoopOp->getResults (), newLoopOp->getResults ())) {
1819
1961
rewriter.replaceAllUsesWith (oldResult, newResult);
1820
1962
}
1821
1963
1964
+ Operation *newTopLevelLoop =
1965
+ isNestedForOps ? newNestedForOps.front () : newLoopOp;
1822
1966
for (auto &&[oldResult, newResult] :
1823
1967
llvm::zip (consumerOp->getResults (),
1824
- newLoopOp ->getResults ().drop_front (initSize))) {
1968
+ newTopLevelLoop ->getResults ().drop_front (initSize))) {
1825
1969
rewriter.replaceAllUsesWith (oldResult, newResult);
1826
1970
}
1827
1971
1828
- // 13 . Need to erase the old scf loop and the cloned consumer op.
1972
+ // 14 . Need to erase the old scf loop and the cloned consumer op.
1829
1973
rewriter.eraseOp (oldLoopOp);
1830
1974
rewriter.eraseOp (clonedConsumerOp);
1831
1975
@@ -1835,6 +1979,110 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
1835
1979
tileAndFuseResult->tiledOps };
1836
1980
}
1837
1981
1982
+ // / Get the real consumers from candidate InsertSliceOp. E.g
1983
+ // /
1984
+ // / ```
1985
+ // / %1 = scf.for
1986
+ // / %2 = scf.for
1987
+ // / %3 = scf.for
1988
+ // / ...
1989
+ // / %4 = insert
1990
+ // / yield %4
1991
+ // / %5 = insert %3
1992
+ // / yield %5
1993
+ // / yield %2
1994
+ // / %6 = consumerOp ins(%1)
1995
+ // / ```
1996
+ // /
1997
+ // / @param candidateSliceOp: %4 = insert
1998
+ // / @param forwardSlice: in-out parameter populated by forward insertSliceOps
1999
+ // / @return OpOperand consumers: %6 = consumerOp ins(%1)
2000
+ static FailureOr<SmallVector<OpOperand *>> getRealConsumersFromInsertSliceOp (
2001
+ Operation *candidateSliceOp,
2002
+ SmallVector<OffsetSizeAndStrideOpInterface> &forwardSlice,
2003
+ unsigned curDepth = 0 , unsigned maxDepth = 5 ) {
2004
+ if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
2005
+ candidateSliceOp))
2006
+ return failure ();
2007
+ // Control recursive time in avoid of stack overflow
2008
+ if (curDepth > maxDepth)
2009
+ return failure ();
2010
+
2011
+ forwardSlice.push_back (
2012
+ cast<OffsetSizeAndStrideOpInterface>(candidateSliceOp));
2013
+ Value resultOfLoop;
2014
+ if (auto sliceOp =
2015
+ dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
2016
+ Value destValue = sliceOp.getDest ();
2017
+ auto iterArg = cast<BlockArgument>(destValue);
2018
+ auto forallOp = dyn_cast<scf::ForallOp>(iterArg.getOwner ()->getParentOp ());
2019
+ if (!forallOp)
2020
+ return failure ();
2021
+ resultOfLoop = forallOp.getTiedOpResult (forallOp.getTiedOpOperand (iterArg));
2022
+ } else if (auto sliceOp = dyn_cast<tensor::InsertSliceOp>(candidateSliceOp)) {
2023
+ Value resultValue = sliceOp.getResult ();
2024
+ for (auto &useOperand : resultValue.getUses ()) {
2025
+ if (auto yieldOp = dyn_cast<scf::YieldOp>(useOperand.getOwner ())) {
2026
+ if (llvm::detail::isPresent (resultOfLoop))
2027
+ return failure ();
2028
+ auto forOp = dyn_cast<LoopLikeOpInterface>(yieldOp->getParentOp ());
2029
+ if (!forOp)
2030
+ return failure ();
2031
+ resultOfLoop = forOp->getResult (useOperand.getOperandNumber ());
2032
+ }
2033
+ }
2034
+ }
2035
+
2036
+ if (!llvm::detail::isPresent (resultOfLoop))
2037
+ return failure ();
2038
+
2039
+ bool traverseUpperLoop;
2040
+ do {
2041
+ traverseUpperLoop = false ;
2042
+ for (OpOperand &useOperand : resultOfLoop.getUses ()) {
2043
+ if (auto sliceOp =
2044
+ dyn_cast<OffsetSizeAndStrideOpInterface>(useOperand.getOwner ())) {
2045
+ return getRealConsumersFromInsertSliceOp (sliceOp, forwardSlice,
2046
+ curDepth + 1 );
2047
+ }
2048
+ if (auto yieldOp = dyn_cast<scf::YieldOp>(useOperand.getOwner ())) {
2049
+ // Walk through outer loop.
2050
+ auto forOp = dyn_cast<LoopLikeOpInterface>(yieldOp->getParentOp ());
2051
+ if (!forOp)
2052
+ return failure ();
2053
+ resultOfLoop = forOp->getResult (useOperand.getOperandNumber ());
2054
+ traverseUpperLoop = true ;
2055
+ break ;
2056
+ }
2057
+ }
2058
+ } while (traverseUpperLoop);
2059
+ // Return all operands using result of top level loop.
2060
+ return llvm::map_to_vector (resultOfLoop.getUses (),
2061
+ [](OpOperand &u) -> OpOperand * { return &u; });
2062
+ }
2063
+
2064
+ // / Fusing real consumer of a single slice even within complex nested loops via
2065
+ // / multiple application of `tileAndFuseConsumerOfSliceImpl`.
2066
+ FailureOr<scf::SCFFuseConsumerOfSliceResult>
2067
+ mlir::scf::tileAndFuseConsumerOfSlice (RewriterBase &rewriter,
2068
+ Operation *candidateSliceOp) {
2069
+ SmallVector<OffsetSizeAndStrideOpInterface> forwardSlice;
2070
+ if (failed (getRealConsumersFromInsertSliceOp (candidateSliceOp, forwardSlice)))
2071
+ return failure ();
2072
+
2073
+ FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResult;
2074
+ // Reverse forward slice from outer to inner.
2075
+ std::reverse (forwardSlice.begin (), forwardSlice.end ());
2076
+ // Multiple application of `tileAndFuseConsumerOfSliceImpl`.
2077
+ for (auto &sliceOp : forwardSlice) {
2078
+ fuseConsumerResult = tileAndFuseConsumerOfSliceImpl (rewriter, sliceOp);
2079
+ if (failed (fuseConsumerResult))
2080
+ return rewriter.notifyMatchFailure (sliceOp,
2081
+ " could not fuse consumer of sliceOp" );
2082
+ }
2083
+ return fuseConsumerResult;
2084
+ }
2085
+
1838
2086
// ===----------------------------------------------------------------------===//
1839
2087
// lowerToLoopsUsingSCFForOp implementation.
1840
2088
// ===----------------------------------------------------------------------===//
0 commit comments