Skip to content

Commit 5aeb604

Browse files
[mlir][SCF] Modernize coalesceLoops method to handle scf.for loops with iter_args (#87019)
As part of this extension, this change also does some general cleanup: 1) Make all the methods take `RewriterBase` as an argument instead of creating their own builders, which tend to crash when used within pattern rewrites. 2) Split `coalescePerfectlyNestedLoops` into two separate methods, one for `scf.for` and the other for `affine.for`. The templatization didn't seem to be buying much there. Also general cleanup of tests.
1 parent fd2a5c4 commit 5aeb604

File tree

13 files changed

+587
-240
lines changed

13 files changed

+587
-240
lines changed

mlir/include/mlir/Dialect/Affine/LoopUtils.h

Lines changed: 2 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -299,53 +299,8 @@ LogicalResult
299299
separateFullTiles(MutableArrayRef<AffineForOp> nest,
300300
SmallVectorImpl<AffineForOp> *fullTileNest = nullptr);
301301

302-
/// Walk either an scf.for or an affine.for to find a band to coalesce.
303-
template <typename LoopOpTy>
304-
LogicalResult coalescePerfectlyNestedLoops(LoopOpTy op) {
305-
LogicalResult result(failure());
306-
SmallVector<LoopOpTy> loops;
307-
getPerfectlyNestedLoops(loops, op);
308-
309-
// Look for a band of loops that can be coalesced, i.e. perfectly nested
310-
// loops with bounds defined above some loop.
311-
// 1. For each loop, find above which parent loop its operands are
312-
// defined.
313-
SmallVector<unsigned, 4> operandsDefinedAbove(loops.size());
314-
for (unsigned i = 0, e = loops.size(); i < e; ++i) {
315-
operandsDefinedAbove[i] = i;
316-
for (unsigned j = 0; j < i; ++j) {
317-
if (areValuesDefinedAbove(loops[i].getOperands(), loops[j].getRegion())) {
318-
operandsDefinedAbove[i] = j;
319-
break;
320-
}
321-
}
322-
}
323-
324-
// 2. Identify bands of loops such that the operands of all of them are
325-
// defined above the first loop in the band. Traverse the nest bottom-up
326-
// so that modifications don't invalidate the inner loops.
327-
for (unsigned end = loops.size(); end > 0; --end) {
328-
unsigned start = 0;
329-
for (; start < end - 1; ++start) {
330-
auto maxPos =
331-
*std::max_element(std::next(operandsDefinedAbove.begin(), start),
332-
std::next(operandsDefinedAbove.begin(), end));
333-
if (maxPos > start)
334-
continue;
335-
assert(maxPos == start &&
336-
"expected loop bounds to be known at the start of the band");
337-
auto band = llvm::MutableArrayRef(loops.data() + start, end - start);
338-
if (succeeded(coalesceLoops(band)))
339-
result = success();
340-
break;
341-
}
342-
// If a band was found and transformed, keep looking at the loops above
343-
// the outermost transformed loop.
344-
if (start != end - 1)
345-
end = start + 1;
346-
}
347-
return result;
348-
}
302+
/// Walk an affine.for to find a band to coalesce.
303+
LogicalResult coalescePerfectlyNestedAffineLoops(AffineForOp op);
349304

350305
} // namespace affine
351306
} // namespace mlir

mlir/include/mlir/Dialect/SCF/Utils/Utils.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,11 +100,16 @@ getSCFMinMaxExpr(Value value, SmallVectorImpl<Value> &dims,
100100
/// `loops` contains a list of perfectly nested loops with bounds and steps
101101
/// independent of any loop induction variable involved in the nest.
102102
LogicalResult coalesceLoops(MutableArrayRef<scf::ForOp> loops);
103+
LogicalResult coalesceLoops(RewriterBase &rewriter,
104+
MutableArrayRef<scf::ForOp>);
105+
106+
/// Walk an scf.for to find a band to coalesce.
107+
LogicalResult coalescePerfectlyNestedSCFForLoops(scf::ForOp op);
103108

104109
/// Take the ParallelLoop and for each set of dimension indices, combine them
105110
/// into a single dimension. combinedDimensions must contain each index into
106111
/// loops exactly once.
107-
void collapseParallelLoops(scf::ParallelOp loops,
112+
void collapseParallelLoops(RewriterBase &rewriter, scf::ParallelOp loops,
108113
ArrayRef<std::vector<unsigned>> combinedDimensions);
109114

110115
/// Unrolls this for operation by the specified unroll factor. Returns failure

mlir/include/mlir/IR/PatternMatch.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "llvm/Support/TypeName.h"
1616
#include <optional>
1717

18+
using llvm::SmallPtrSetImpl;
1819
namespace mlir {
1920

2021
class PatternRewriter;
@@ -704,6 +705,8 @@ class RewriterBase : public OpBuilder {
704705
return user != exceptedUser;
705706
});
706707
}
708+
void replaceAllUsesExcept(Value from, Value to,
709+
const SmallPtrSetImpl<Operation *> &preservedUsers);
707710

708711
/// Used to notify the listener that the IR failed to be rewritten because of
709712
/// a match failure, and provide a callback to populate a diagnostic with the

mlir/lib/Dialect/Affine/Transforms/LoopCoalescing.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,9 @@ struct LoopCoalescingPass
3939
func::FuncOp func = getOperation();
4040
func.walk<WalkOrder::PreOrder>([](Operation *op) {
4141
if (auto scfForOp = dyn_cast<scf::ForOp>(op))
42-
(void)coalescePerfectlyNestedLoops(scfForOp);
42+
(void)coalescePerfectlyNestedSCFForLoops(scfForOp);
4343
else if (auto affineForOp = dyn_cast<AffineForOp>(op))
44-
(void)coalescePerfectlyNestedLoops(affineForOp);
44+
(void)coalescePerfectlyNestedAffineLoops(affineForOp);
4545
});
4646
}
4747
};

mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2765,3 +2765,51 @@ mlir::affine::separateFullTiles(MutableArrayRef<AffineForOp> inputNest,
27652765

27662766
return success();
27672767
}
2768+
2769+
LogicalResult affine::coalescePerfectlyNestedAffineLoops(AffineForOp op) {
2770+
LogicalResult result(failure());
2771+
SmallVector<AffineForOp> loops;
2772+
getPerfectlyNestedLoops(loops, op);
2773+
if (loops.size() <= 1)
2774+
return success();
2775+
2776+
// Look for a band of loops that can be coalesced, i.e. perfectly nested
2777+
// loops with bounds defined above some loop.
2778+
// 1. For each loop, find above which parent loop its operands are
2779+
// defined.
2780+
SmallVector<unsigned> operandsDefinedAbove(loops.size());
2781+
for (unsigned i = 0, e = loops.size(); i < e; ++i) {
2782+
operandsDefinedAbove[i] = i;
2783+
for (unsigned j = 0; j < i; ++j) {
2784+
if (areValuesDefinedAbove(loops[i].getOperands(), loops[j].getRegion())) {
2785+
operandsDefinedAbove[i] = j;
2786+
break;
2787+
}
2788+
}
2789+
}
2790+
2791+
// 2. Identify bands of loops such that the operands of all of them are
2792+
// defined above the first loop in the band. Traverse the nest bottom-up
2793+
// so that modifications don't invalidate the inner loops.
2794+
for (unsigned end = loops.size(); end > 0; --end) {
2795+
unsigned start = 0;
2796+
for (; start < end - 1; ++start) {
2797+
auto maxPos =
2798+
*std::max_element(std::next(operandsDefinedAbove.begin(), start),
2799+
std::next(operandsDefinedAbove.begin(), end));
2800+
if (maxPos > start)
2801+
continue;
2802+
assert(maxPos == start &&
2803+
"expected loop bounds to be known at the start of the band");
2804+
auto band = llvm::MutableArrayRef(loops.data() + start, end - start);
2805+
if (succeeded(coalesceLoops(band)))
2806+
result = success();
2807+
break;
2808+
}
2809+
// If a band was found and transformed, keep looking at the loops above
2810+
// the outermost transformed loop.
2811+
if (start != end - 1)
2812+
end = start + 1;
2813+
}
2814+
return result;
2815+
}

mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -332,9 +332,9 @@ transform::LoopCoalesceOp::applyToOne(transform::TransformRewriter &rewriter,
332332
transform::TransformState &state) {
333333
LogicalResult result(failure());
334334
if (scf::ForOp scfForOp = dyn_cast<scf::ForOp>(op))
335-
result = coalescePerfectlyNestedLoops(scfForOp);
335+
result = coalescePerfectlyNestedSCFForLoops(scfForOp);
336336
else if (AffineForOp affineForOp = dyn_cast<AffineForOp>(op))
337-
result = coalescePerfectlyNestedLoops(affineForOp);
337+
result = coalescePerfectlyNestedAffineLoops(affineForOp);
338338

339339
results.push_back(op);
340340
if (failed(result)) {

mlir/lib/Dialect/SCF/Transforms/ParallelLoopCollapsing.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ namespace {
2828
struct TestSCFParallelLoopCollapsing
2929
: public impl::TestSCFParallelLoopCollapsingBase<
3030
TestSCFParallelLoopCollapsing> {
31+
3132
void runOnOperation() override {
3233
Operation *module = getOperation();
3334

@@ -88,6 +89,7 @@ struct TestSCFParallelLoopCollapsing
8889
// Only apply the transformation on parallel loops where the specified
8990
// transformation is valid, but do NOT early abort in the case of invalid
9091
// loops.
92+
IRRewriter rewriter(&getContext());
9193
module->walk([&](scf::ParallelOp op) {
9294
if (flattenedCombinedLoops.size() != op.getNumLoops()) {
9395
op.emitOpError("has ")
@@ -97,7 +99,7 @@ struct TestSCFParallelLoopCollapsing
9799
<< flattenedCombinedLoops.size() << " iter args.";
98100
return;
99101
}
100-
collapseParallelLoops(op, combinedLoops);
102+
collapseParallelLoops(rewriter, op, combinedLoops);
101103
});
102104
}
103105
};

0 commit comments

Comments
 (0)