@@ -1846,11 +1846,9 @@ static FailureOr<OpOperand *> getConsumerFromLoopUses(RewriterBase &rewriter,
1846
1846
return failure ();
1847
1847
}
1848
1848
1849
- // / Find the perfectly nested loops outside of given loop(included) sorted
1850
- // / from outer to inner.
1851
- // /
1852
- // / E.g.
1853
- // /
1849
+ // / Check that the loop is perfectly nested.
1850
+ // / The loops are expected to be ordered from outer most to inner most.
1851
+ // / For example:
1854
1852
// / ```
1855
1853
// / %0 = scf.for()
1856
1854
// / %1 = scf.for()
@@ -1860,55 +1858,85 @@ static FailureOr<OpOperand *> getConsumerFromLoopUses(RewriterBase &rewriter,
1860
1858
// / yield %2
1861
1859
// / yield %1
1862
1860
// / ```
1863
- // /
1864
- // / This function will return three perfectly nested loops: %0 + %1 + %2, when
1865
- // / target inner loop is %2.
1866
- static SmallVector<scf::ForOp>
1867
- getPerfectlyNestedLoopsOutsideOf (scf::ForOp loop) {
1868
- SmallVector<scf::ForOp> nestLoops = {loop};
1869
- auto outerLoop = dyn_cast<scf::ForOp>(loop->getParentOp ());
1870
-
1871
- // Check if it is the ForOp that yield the result of inner loop.
1872
- auto isForOpYieldResultOfInnerLoop =
1873
- [](scf::ForOp outerLoop) -> LogicalResult {
1874
- Block *body = outerLoop.getBody ();
1875
- if (!llvm::hasSingleElement (body->without_terminator ()))
1876
- return failure ();
1877
- auto yieldOp = cast<scf::YieldOp>(body->getTerminator ());
1878
- auto innerForOp = dyn_cast<scf::ForOp>(body->front ());
1879
- if (!innerForOp)
1880
- return failure ();
1881
- // All of innerForOp results should be yielded.
1882
- return success (innerForOp->getNumResults () == yieldOp->getNumOperands ());
1883
- };
1861
+ // / Here loops should be [%0, %1].
1862
+ static bool
1863
+ isPerfectlyNestedForLoops (MutableArrayRef<LoopLikeOpInterface> loops) {
1864
+ assert (!loops.empty () && " unexpected empty loop nest" );
1865
+ if (loops.size () == 1 ) {
1866
+ return isa_and_nonnull<scf::ForOp>(loops.front ().getOperation ());
1867
+ }
1868
+ for (auto [outerLoop, innerLoop] :
1869
+ llvm::zip_equal (loops.drop_back (), loops.drop_front ())) {
1870
+ auto outerFor = dyn_cast_or_null<scf::ForOp>(outerLoop.getOperation ());
1871
+ auto innerFor = dyn_cast_or_null<scf::ForOp>(innerLoop.getOperation ());
1872
+ if (!outerFor || !innerFor) {
1873
+ return false ;
1874
+ }
1875
+ auto outerBBArgs = outerFor.getRegionIterArgs ();
1876
+ auto innerIterArgs = innerFor.getInitArgs ();
1877
+ if (outerBBArgs.size () != innerIterArgs.size ()) {
1878
+ return false ;
1879
+ }
1880
+
1881
+ for (auto [outerBBArg, innerIterArg] :
1882
+ llvm::zip_equal (outerBBArgs, innerIterArgs)) {
1883
+ if (!llvm::hasSingleElement (outerBBArg.getUses ()) ||
1884
+ innerIterArg != outerBBArg) {
1885
+ return false ;
1886
+ }
1887
+ }
1884
1888
1885
- while (outerLoop && succeeded (isForOpYieldResultOfInnerLoop (outerLoop))) {
1886
- nestLoops.push_back (outerLoop);
1887
- outerLoop = dyn_cast<scf::ForOp>(outerLoop->getParentOp ());
1889
+ ValueRange outerYields =
1890
+ cast<scf::YieldOp>(outerFor.getBody ()->getTerminator ())->getOperands ();
1891
+ ValueRange innerResults = innerFor.getResults ();
1892
+ if (outerYields.size () != innerResults.size ()) {
1893
+ return false ;
1894
+ }
1895
+ for (auto [outerYield, innerResult] :
1896
+ llvm::zip_equal (outerYields, innerResults)) {
1897
+ if (!llvm::hasSingleElement (innerResult.getUses ()) ||
1898
+ outerYield != innerResult) {
1899
+ return false ;
1900
+ }
1901
+ }
1888
1902
}
1889
- // sorted from outer to inner
1890
- return {nestLoops.rbegin (), nestLoops.rend ()};
1903
+ return true ;
1891
1904
}
1892
1905
1893
- // / Fetch the untiled consumer of a scf.for's result which is yielded by a
1894
- // / tensor.insert_slice. This function makes the following assumptions :
1895
- // / 1. tensor.insert_slice has scf.yield as its only user.
1896
- // / 2. scf.for's corresponding result has only one use.
1906
+ // / Fetch the untiled consumer of the outermost scf.for's result which is
1907
+ // / yielded by a tensor.insert_slice from the innermost scf.for. This function
1908
+ // / makes the following assumptions :
1909
+ // / 1. tensor.insert_slice has scf.yield as its only user.
1910
+ // / 2. scf.for's corresponding result has only one use.
1911
+ // / 3. The `loops` passed in are perfectly nested `scf.for` operations.
1897
1912
static FailureOr<OpOperand *>
1898
1913
getUntiledConsumerFromSlice (RewriterBase &rewriter,
1899
- tensor::InsertSliceOp candidateSliceOp) {
1914
+ tensor::InsertSliceOp candidateSliceOp,
1915
+ MutableArrayRef<LoopLikeOpInterface> loops) {
1916
+ assert (!loops.empty () && " unexpected loops to be empty" );
1917
+ // 1. Expect slice to be part of the body of the inner most loop.
1918
+ Operation *containingOp = candidateSliceOp->getParentOp ();
1919
+ if (containingOp != loops.back ()) {
1920
+ return rewriter.notifyMatchFailure (
1921
+ candidateSliceOp,
1922
+ " expected slice to be within body of inner-most loop" );
1923
+ }
1924
+
1925
+ // 2. Check that the loop is perfectly nested.
1926
+ if (!isPerfectlyNestedForLoops (loops)) {
1927
+ return rewriter.notifyMatchFailure (
1928
+ candidateSliceOp, " expected passed loops to be perfectly nested." );
1929
+ }
1930
+
1900
1931
if (failed (checkAssumptionForFusingConsumer (candidateSliceOp)))
1901
1932
return failure ();
1902
1933
Value sliceResult = candidateSliceOp.getResult ();
1903
- // Step 1. Fetch the corresponding output.
1934
+
1935
+ // 3. Fetch the corresponding output.
1904
1936
OpOperand &yieldOpOperand = (*sliceResult.getUses ().begin ());
1905
1937
unsigned resultNumber = yieldOpOperand.getOperandNumber ();
1906
- // Step 2. Check containing op is scf.for.
1907
- Operation *containingOp = candidateSliceOp->getParentOp ();
1908
- auto forOp = dyn_cast<scf::ForOp>(containingOp);
1909
- if (!forOp)
1910
- return failure ();
1911
- scf::ForOp topLevelForOp = getPerfectlyNestedLoopsOutsideOf (forOp).front ();
1938
+
1939
+ scf::ForOp topLevelForOp = cast<scf::ForOp>(loops.front ().getOperation ());
1912
1940
1913
1941
return getConsumerFromLoopUses (rewriter, topLevelForOp, resultNumber);
1914
1942
}
@@ -1917,35 +1945,46 @@ getUntiledConsumerFromSlice(RewriterBase &rewriter,
1917
1945
// / by a tensor.parallel_insert_slice.
1918
1946
static FailureOr<OpOperand *>
getUntiledConsumerFromSlice(RewriterBase &rewriter,
                            tensor::ParallelInsertSliceOp candidateSliceOp,
                            MutableArrayRef<LoopLikeOpInterface> loops) {
  assert(!loops.empty() && " unexpected loops to be empty");

  // Exactly one surrounding loop is accepted, and it must be an `scf.forall`.
  // Both violations are reported with the same diagnostic.
  scf::ForallOp enclosingForall =
      loops.size() == 1
          ? dyn_cast<scf::ForallOp>(loops.front().getOperation())
          : scf::ForallOp();
  if (!enclosingForall)
    return rewriter.notifyMatchFailure(
        candidateSliceOp, " expected single surrounding scf.forall");

  // The slice destination must be a block argument owned by that forall.
  auto destArg = dyn_cast<BlockArgument>(candidateSliceOp.getDest());
  if (!destArg)
    return failure();
  Operation *argOwnerOp = destArg.getOwner()->getParentOp();
  if (argOwnerOp != enclosingForall.getOperation())
    return failure();

  // Map the block argument to its tied operand, then to the tied result, to
  // learn which forall result the slice contributes to.
  OpOperand *tiedOperand = enclosingForall.getTiedOpOperand(destArg);
  unsigned tiedResultIndex =
      enclosingForall.getTiedOpResult(tiedOperand).getResultNumber();

  return getConsumerFromLoopUses(rewriter, enclosingForall, tiedResultIndex);
}
1939
1976
1940
1977
// / A utility to fetch an untiled consumer of
1941
1978
// / tensor.insert_slice/tensor.parallel_insert_slice.
1942
1979
static FailureOr<OpOperand *>
1943
- getUntiledConsumerFromSlice (RewriterBase &rewriter, Operation *sliceOp) {
1980
+ getUntiledConsumerFromSlice (RewriterBase &rewriter, Operation *sliceOp,
1981
+ MutableArrayRef<LoopLikeOpInterface> loops) {
1982
+ assert (!loops.empty () && " unexpected empty loops" );
1944
1983
if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) {
1945
- return getUntiledConsumerFromSlice (rewriter, insertSlice);
1984
+ return getUntiledConsumerFromSlice (rewriter, insertSlice, loops );
1946
1985
} else if (auto parallelInsertSlice =
1947
1986
dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) {
1948
- return getUntiledConsumerFromSlice (rewriter, parallelInsertSlice);
1987
+ return getUntiledConsumerFromSlice (rewriter, parallelInsertSlice, loops );
1949
1988
} else {
1950
1989
return failure ();
1951
1990
}
@@ -1954,18 +1993,23 @@ getUntiledConsumerFromSlice(RewriterBase &rewriter, Operation *sliceOp) {
1954
1993
// / Implementation of fusing consumer of a single slice by computing the
1955
1994
// / slice of the consumer in-place for scf loop.
1956
1995
FailureOr<scf::SCFFuseConsumerOfSliceResult>
1957
- mlir::scf::tileAndFuseConsumerOfSlice (RewriterBase &rewriter,
1958
- Operation *candidateSliceOp) {
1996
+ mlir::scf::tileAndFuseConsumerOfSlice (
1997
+ RewriterBase &rewriter, Operation *candidateSliceOp,
1998
+ MutableArrayRef<LoopLikeOpInterface> loops) {
1999
+ // Return if `loops` is empty, return an error for now. Caller is expected
2000
+ // to handle this case.
2001
+ if (loops.empty ()) {
2002
+ return candidateSliceOp->emitOpError (
2003
+ " cannot call tile and fuse consumer with an empty loop nest" );
2004
+ }
1959
2005
if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
1960
2006
candidateSliceOp))
1961
2007
return failure ();
1962
2008
1963
- bool isInsertSliceOp = isa<tensor::InsertSliceOp>(candidateSliceOp);
1964
-
1965
2009
// 1. Get the consumer of scf.for for the result yielded by
1966
2010
// tensor.insert_slice/parallel_insert_slice.
1967
2011
FailureOr<OpOperand *> maybeConsumerOpOperand =
1968
- getUntiledConsumerFromSlice (rewriter, candidateSliceOp);
2012
+ getUntiledConsumerFromSlice (rewriter, candidateSliceOp, loops );
1969
2013
if (failed (maybeConsumerOpOperand)) {
1970
2014
return rewriter.notifyMatchFailure (candidateSliceOp,
1971
2015
" could not fetch consumer to fuse" );
@@ -1981,25 +2025,8 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
1981
2025
consumerOp, " consumer op's operand doesn't seem to be an OpResult" );
1982
2026
}
1983
2027
1984
- // There are two possible cases regarding `oldLoopOp` here:
1985
- // 1. single `scf.forall` or `scf.for`.
1986
- // 2. inner-most `scf.for` insider nest `scf.loop` structure, where the
1987
- // top-level loop is the outer-most one of these nested loops.
1988
- LoopLikeOpInterface innerMostLoop =
1989
- candidateSliceOp->getParentOfType <LoopLikeOpInterface>();
1990
- SmallVector<LoopLikeOpInterface> nestedLoops;
1991
- if (isInsertSliceOp) {
1992
- nestedLoops = llvm::map_to_vector (
1993
- getPerfectlyNestedLoopsOutsideOf (
1994
- cast<scf::ForOp>(innerMostLoop.getOperation ())),
1995
- [](scf::ForOp forOp) {
1996
- return cast<LoopLikeOpInterface>(forOp.getOperation ());
1997
- });
1998
- } else {
1999
- nestedLoops = {innerMostLoop};
2000
- }
2001
-
2002
- LoopLikeOpInterface outerMostLoop = nestedLoops.front ();
2028
+ LoopLikeOpInterface outerMostLoop = loops.front ();
2029
+ LoopLikeOpInterface innerMostLoop = loops.back ();
2003
2030
2004
2031
// Check assumption for loop with `reorderOperations` disabled.
2005
2032
if (failed (checkAssumptionForLoop (outerMostLoop, consumerOp, false ))) {
@@ -2165,7 +2192,7 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
2165
2192
return success ();
2166
2193
};
2167
2194
// 14. Add new inits to [nested] loops.
2168
- if (failed (addInitOperandsToLoopNest (rewriter, nestedLoops , newInits,
2195
+ if (failed (addInitOperandsToLoopNest (rewriter, loops , newInits,
2169
2196
newYieldValuesFn))) {
2170
2197
return rewriter.notifyMatchFailure (tiledConsumerOp,
2171
2198
" unable to add new inits to nest loop" );
@@ -2174,9 +2201,9 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
2174
2201
// 15. Replace the result of scf loop and consumer op with new loop's
2175
2202
// results.
2176
2203
2177
- for (auto &&[oldResult, newResult] : llvm::zip (
2178
- consumerOp->getResults (),
2179
- nestedLoops .front ()->getResults ().take_back (newInits.size ()))) {
2204
+ for (auto &&[oldResult, newResult] :
2205
+ llvm::zip ( consumerOp->getResults (),
2206
+ loops .front ()->getResults ().take_back (newInits.size ()))) {
2180
2207
rewriter.replaceAllUsesWith (oldResult, newResult);
2181
2208
}
2182
2209
0 commit comments