Skip to content

Commit e417219

Browse files
[mlir][TilingInterface] Make tileAndFuseConsumerOfSlice take surrounding loops as an argument. (#132082)
This gets the consumer fusion method in sync with the corresponding producer fusion method `tileAndFuseProducerOfSlice`. Not taking the loops as input required the use of complicated analysis to retrieve the surrounding loops, which is very fragile. Just like the producer fusion method, the loops need to be taken in as an argument, with the loops typically being created by the tiling methods. Some utilities are added to check that the loops passed in are perfectly nested (in the case of an `scf.for` loop nest). This is change 1 of N to simplify the implementation of tile and fuse consumers. --------- Signed-off-by: MaheshRavishankar <[email protected]>
1 parent 8d3dc1e commit e417219

File tree

6 files changed

+175
-253
lines changed

6 files changed

+175
-253
lines changed

mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,8 @@ struct SCFFuseConsumerOfSliceResult {
328328
SmallVector<Operation *> tiledOps;
329329
};
330330
FailureOr<scf::SCFFuseConsumerOfSliceResult>
331-
tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp);
331+
tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp,
332+
MutableArrayRef<LoopLikeOpInterface> loops);
332333

333334
/// Method to lower an `op` that implements the `TilingInterface` to
334335
/// loops/scalars.

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 110 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -1846,11 +1846,9 @@ static FailureOr<OpOperand *> getConsumerFromLoopUses(RewriterBase &rewriter,
18461846
return failure();
18471847
}
18481848

1849-
/// Find the perfectly nested loops outside of given loop(included) sorted
1850-
/// from outer to inner.
1851-
///
1852-
/// E.g.
1853-
///
1849+
/// Check that the given loops form a perfectly nested `scf.for` loop nest.
1850+
/// The loops are expected to be ordered from outer most to inner most.
1851+
/// For example:
18541852
/// ```
18551853
/// %0 = scf.for()
18561854
/// %1 = scf.for()
@@ -1860,55 +1858,85 @@ static FailureOr<OpOperand *> getConsumerFromLoopUses(RewriterBase &rewriter,
18601858
/// yield %2
18611859
/// yield %1
18621860
/// ```
1863-
///
1864-
/// This function will return three perfectly nested loops: %0 + %1 + %2, when
1865-
/// target inner loop is %2.
1866-
static SmallVector<scf::ForOp>
1867-
getPerfectlyNestedLoopsOutsideOf(scf::ForOp loop) {
1868-
SmallVector<scf::ForOp> nestLoops = {loop};
1869-
auto outerLoop = dyn_cast<scf::ForOp>(loop->getParentOp());
1870-
1871-
// Check if it is the ForOp that yield the result of inner loop.
1872-
auto isForOpYieldResultOfInnerLoop =
1873-
[](scf::ForOp outerLoop) -> LogicalResult {
1874-
Block *body = outerLoop.getBody();
1875-
if (!llvm::hasSingleElement(body->without_terminator()))
1876-
return failure();
1877-
auto yieldOp = cast<scf::YieldOp>(body->getTerminator());
1878-
auto innerForOp = dyn_cast<scf::ForOp>(body->front());
1879-
if (!innerForOp)
1880-
return failure();
1881-
// All of innerForOp results should be yielded.
1882-
return success(innerForOp->getNumResults() == yieldOp->getNumOperands());
1883-
};
1861+
/// Here loops should be [%0, %1].
1862+
static bool
1863+
isPerfectlyNestedForLoops(MutableArrayRef<LoopLikeOpInterface> loops) {
1864+
assert(!loops.empty() && "unexpected empty loop nest");
1865+
if (loops.size() == 1) {
1866+
return isa_and_nonnull<scf::ForOp>(loops.front().getOperation());
1867+
}
1868+
for (auto [outerLoop, innerLoop] :
1869+
llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
1870+
auto outerFor = dyn_cast_or_null<scf::ForOp>(outerLoop.getOperation());
1871+
auto innerFor = dyn_cast_or_null<scf::ForOp>(innerLoop.getOperation());
1872+
if (!outerFor || !innerFor) {
1873+
return false;
1874+
}
1875+
auto outerBBArgs = outerFor.getRegionIterArgs();
1876+
auto innerIterArgs = innerFor.getInitArgs();
1877+
if (outerBBArgs.size() != innerIterArgs.size()) {
1878+
return false;
1879+
}
1880+
1881+
for (auto [outerBBArg, innerIterArg] :
1882+
llvm::zip_equal(outerBBArgs, innerIterArgs)) {
1883+
if (!llvm::hasSingleElement(outerBBArg.getUses()) ||
1884+
innerIterArg != outerBBArg) {
1885+
return false;
1886+
}
1887+
}
18841888

1885-
while (outerLoop && succeeded(isForOpYieldResultOfInnerLoop(outerLoop))) {
1886-
nestLoops.push_back(outerLoop);
1887-
outerLoop = dyn_cast<scf::ForOp>(outerLoop->getParentOp());
1889+
ValueRange outerYields =
1890+
cast<scf::YieldOp>(outerFor.getBody()->getTerminator())->getOperands();
1891+
ValueRange innerResults = innerFor.getResults();
1892+
if (outerYields.size() != innerResults.size()) {
1893+
return false;
1894+
}
1895+
for (auto [outerYield, innerResult] :
1896+
llvm::zip_equal(outerYields, innerResults)) {
1897+
if (!llvm::hasSingleElement(innerResult.getUses()) ||
1898+
outerYield != innerResult) {
1899+
return false;
1900+
}
1901+
}
18881902
}
1889-
// sorted from outer to inner
1890-
return {nestLoops.rbegin(), nestLoops.rend()};
1903+
return true;
18911904
}
18921905

1893-
/// Fetch the untiled consumer of a scf.for's result which is yielded by a
1894-
/// tensor.insert_slice. This function makes the following assumptions :
1895-
/// 1. tensor.insert_slice has scf.yield as its only user.
1896-
/// 2. scf.for's corresponding result has only one use.
1906+
/// Fetch the untiled consumer of the outermost scf.for's result which is
1907+
/// yielded by a tensor.insert_slice from the innermost scf.for. This function
1908+
/// makes the following assumptions :
1909+
/// 1. tensor.insert_slice has scf.yield as its only user.
1910+
/// 2. scf.for's corresponding result has only one use.
1911+
/// 3. The `loops` passed in are perfectly nested `scf.for` operations.
18971912
static FailureOr<OpOperand *>
18981913
getUntiledConsumerFromSlice(RewriterBase &rewriter,
1899-
tensor::InsertSliceOp candidateSliceOp) {
1914+
tensor::InsertSliceOp candidateSliceOp,
1915+
MutableArrayRef<LoopLikeOpInterface> loops) {
1916+
assert(!loops.empty() && "unexpected loops to be empty");
1917+
// 1. Expect slice to be part of the body of the inner most loop.
1918+
Operation *containingOp = candidateSliceOp->getParentOp();
1919+
if (containingOp != loops.back()) {
1920+
return rewriter.notifyMatchFailure(
1921+
candidateSliceOp,
1922+
"expected slice to be within body of inner-most loop");
1923+
}
1924+
1925+
// 2. Check that the loop is perfectly nested.
1926+
if (!isPerfectlyNestedForLoops(loops)) {
1927+
return rewriter.notifyMatchFailure(
1928+
candidateSliceOp, "expected passed loops to be perfectly nested.");
1929+
}
1930+
19001931
if (failed(checkAssumptionForFusingConsumer(candidateSliceOp)))
19011932
return failure();
19021933
Value sliceResult = candidateSliceOp.getResult();
1903-
// Step 1. Fetch the corresponding output.
1934+
1935+
// 3. Fetch the corresponding output.
19041936
OpOperand &yieldOpOperand = (*sliceResult.getUses().begin());
19051937
unsigned resultNumber = yieldOpOperand.getOperandNumber();
1906-
// Step 2. Check containing op is scf.for.
1907-
Operation *containingOp = candidateSliceOp->getParentOp();
1908-
auto forOp = dyn_cast<scf::ForOp>(containingOp);
1909-
if (!forOp)
1910-
return failure();
1911-
scf::ForOp topLevelForOp = getPerfectlyNestedLoopsOutsideOf(forOp).front();
1938+
1939+
scf::ForOp topLevelForOp = cast<scf::ForOp>(loops.front().getOperation());
19121940

19131941
return getConsumerFromLoopUses(rewriter, topLevelForOp, resultNumber);
19141942
}
@@ -1917,35 +1945,46 @@ getUntiledConsumerFromSlice(RewriterBase &rewriter,
19171945
/// by a tensor.parallel_insert_slice.
19181946
static FailureOr<OpOperand *>
19191947
getUntiledConsumerFromSlice(RewriterBase &rewriter,
1920-
tensor::ParallelInsertSliceOp candidateSliceOp) {
1921-
// Step 1. Fetch the corresponding output
1948+
tensor::ParallelInsertSliceOp candidateSliceOp,
1949+
MutableArrayRef<LoopLikeOpInterface> loops) {
1950+
assert(!loops.empty() && "unexpected loops to be empty");
1951+
// 1. Check that the surrounding loop is a single scf.forall loop.
1952+
if (loops.size() != 1) {
1953+
return rewriter.notifyMatchFailure(
1954+
candidateSliceOp, "expected single surrounding scf.forall");
1955+
}
1956+
auto forallOp = dyn_cast<scf::ForallOp>(loops.front().getOperation());
1957+
if (!forallOp) {
1958+
return rewriter.notifyMatchFailure(
1959+
candidateSliceOp, "expected single surrounding scf.forall");
1960+
}
1961+
1962+
// 2. Fetch the corresponding output
19221963
Value sliceDest = candidateSliceOp.getDest();
19231964
auto iterArg = dyn_cast<BlockArgument>(sliceDest);
19241965
if (!iterArg)
19251966
return failure();
1926-
Operation *containingOp = iterArg.getOwner()->getParentOp();
1927-
if (containingOp != candidateSliceOp->getParentOp()->getParentOp())
1928-
return failure();
1929-
// Step 2. Check that the containing op is scf.forall.
1930-
auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
1931-
if (!forallOp)
1967+
if (iterArg.getOwner()->getParentOp() != forallOp)
19321968
return failure();
1969+
19331970
unsigned resultNumber =
19341971
forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg))
19351972
.getResultNumber();
19361973

1937-
return getConsumerFromLoopUses(rewriter, containingOp, resultNumber);
1974+
return getConsumerFromLoopUses(rewriter, forallOp, resultNumber);
19381975
}
19391976

19401977
/// A utility to fetch an untiled consumer of
19411978
/// tensor.insert_slice/tensor.parallel_insert_slice.
19421979
static FailureOr<OpOperand *>
1943-
getUntiledConsumerFromSlice(RewriterBase &rewriter, Operation *sliceOp) {
1980+
getUntiledConsumerFromSlice(RewriterBase &rewriter, Operation *sliceOp,
1981+
MutableArrayRef<LoopLikeOpInterface> loops) {
1982+
assert(!loops.empty() && "unexpected empty loops");
19441983
if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) {
1945-
return getUntiledConsumerFromSlice(rewriter, insertSlice);
1984+
return getUntiledConsumerFromSlice(rewriter, insertSlice, loops);
19461985
} else if (auto parallelInsertSlice =
19471986
dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) {
1948-
return getUntiledConsumerFromSlice(rewriter, parallelInsertSlice);
1987+
return getUntiledConsumerFromSlice(rewriter, parallelInsertSlice, loops);
19491988
} else {
19501989
return failure();
19511990
}
@@ -1954,18 +1993,23 @@ getUntiledConsumerFromSlice(RewriterBase &rewriter, Operation *sliceOp) {
19541993
/// Implementation of fusing consumer of a single slice by computing the
19551994
/// slice of the consumer in-place for scf loop.
19561995
FailureOr<scf::SCFFuseConsumerOfSliceResult>
1957-
mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
1958-
Operation *candidateSliceOp) {
1996+
mlir::scf::tileAndFuseConsumerOfSlice(
1997+
RewriterBase &rewriter, Operation *candidateSliceOp,
1998+
MutableArrayRef<LoopLikeOpInterface> loops) {
1999+
// If `loops` is empty, return an error for now. Caller is expected
2000+
// to handle this case.
2001+
if (loops.empty()) {
2002+
return candidateSliceOp->emitOpError(
2003+
"cannot call tile and fuse consumer with an empty loop nest");
2004+
}
19592005
if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
19602006
candidateSliceOp))
19612007
return failure();
19622008

1963-
bool isInsertSliceOp = isa<tensor::InsertSliceOp>(candidateSliceOp);
1964-
19652009
// 1. Get the consumer of scf.for for the result yielded by
19662010
// tensor.insert_slice/parallel_insert_slice.
19672011
FailureOr<OpOperand *> maybeConsumerOpOperand =
1968-
getUntiledConsumerFromSlice(rewriter, candidateSliceOp);
2012+
getUntiledConsumerFromSlice(rewriter, candidateSliceOp, loops);
19692013
if (failed(maybeConsumerOpOperand)) {
19702014
return rewriter.notifyMatchFailure(candidateSliceOp,
19712015
"could not fetch consumer to fuse");
@@ -1981,25 +2025,8 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
19812025
consumerOp, "consumer op's operand doesn't seem to be an OpResult");
19822026
}
19832027

1984-
// There are two possible cases regarding `oldLoopOp` here:
1985-
// 1. single `scf.forall` or `scf.for`.
1986-
// 2. inner-most `scf.for` insider nest `scf.loop` structure, where the
1987-
// top-level loop is the outer-most one of these nested loops.
1988-
LoopLikeOpInterface innerMostLoop =
1989-
candidateSliceOp->getParentOfType<LoopLikeOpInterface>();
1990-
SmallVector<LoopLikeOpInterface> nestedLoops;
1991-
if (isInsertSliceOp) {
1992-
nestedLoops = llvm::map_to_vector(
1993-
getPerfectlyNestedLoopsOutsideOf(
1994-
cast<scf::ForOp>(innerMostLoop.getOperation())),
1995-
[](scf::ForOp forOp) {
1996-
return cast<LoopLikeOpInterface>(forOp.getOperation());
1997-
});
1998-
} else {
1999-
nestedLoops = {innerMostLoop};
2000-
}
2001-
2002-
LoopLikeOpInterface outerMostLoop = nestedLoops.front();
2028+
LoopLikeOpInterface outerMostLoop = loops.front();
2029+
LoopLikeOpInterface innerMostLoop = loops.back();
20032030

20042031
// Check assumption for loop with `reorderOperations` disabled.
20052032
if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp, false))) {
@@ -2165,7 +2192,7 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
21652192
return success();
21662193
};
21672194
// 14. Add new inits to [nested] loops.
2168-
if (failed(addInitOperandsToLoopNest(rewriter, nestedLoops, newInits,
2195+
if (failed(addInitOperandsToLoopNest(rewriter, loops, newInits,
21692196
newYieldValuesFn))) {
21702197
return rewriter.notifyMatchFailure(tiledConsumerOp,
21712198
"unable to add new inits to nest loop");
@@ -2174,9 +2201,9 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
21742201
// 15. Replace the result of scf loop and consumer op with new loop's
21752202
// results.
21762203

2177-
for (auto &&[oldResult, newResult] : llvm::zip(
2178-
consumerOp->getResults(),
2179-
nestedLoops.front()->getResults().take_back(newInits.size()))) {
2204+
for (auto &&[oldResult, newResult] :
2205+
llvm::zip(consumerOp->getResults(),
2206+
loops.front()->getResults().take_back(newInits.size()))) {
21802207
rewriter.replaceAllUsesWith(oldResult, newResult);
21812208
}
21822209

mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,8 +170,8 @@ module {
170170
// Fuse the consumer operation into the tiled loop.
171171
%slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %forall_op
172172
: (!transform.any_op) -> !transform.op<"tensor.parallel_insert_slice">
173-
transform.test.fuse_consumer %slice_op
174-
: (!transform.op<"tensor.parallel_insert_slice">) -> (!transform.any_op, !transform.any_op)
173+
transform.test.fuse_consumer %slice_op in (%forall_op)
174+
: (!transform.op<"tensor.parallel_insert_slice">, !transform.any_op) -> (!transform.any_op, !transform.any_op)
175175
transform.yield
176176
}
177177
}

0 commit comments

Comments
 (0)