Skip to content

Commit b153c05

Browse files
authored
[mlir][scf] Uplift scf.while to scf.for (#76108)
Add uplifting from `scf.while` to `scf.for`. This uplifting expects a very specific ops pattern: * `before` block consisting of single `arith.cmp` op * `after` block containing `arith.addi` We also have a set of patterns to cleanup `scf.while` loops to get them close to the desired form, they will be added in separate PRs. This is part of upstreaming `numba-mlir` scf uplifting pipeline: `cf -> scf.while -> scf.for -> scf.parallel` Original code: https://github.com/numba/numba-mlir/blob/main/mlir/lib/Transforms/PromoteToParallel.cpp
1 parent c31a810 commit b153c05

File tree

8 files changed

+445
-0
lines changed

8 files changed

+445
-0
lines changed

mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,12 @@ void populateSCFLoopPipeliningPatterns(RewritePatternSet &patterns,
7979
/// loop bounds and loop steps are canonicalized.
8080
void populateSCFForLoopCanonicalizationPatterns(RewritePatternSet &patterns);
8181

82+
/// Populate patterns to uplift `scf.while` ops to `scf.for`.
83+
/// Uplifting expects a specific ops pattern:
84+
/// * `before` block consisting of a single `arith.cmpi` op
85+
/// * `after` block containing arith.addi
86+
void populateUpliftWhileToForPatterns(RewritePatternSet &patterns);
87+
8288
} // namespace scf
8389
} // namespace mlir
8490

mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,12 @@ FailureOr<WhileOp> wrapWhileLoopInZeroTripCheck(WhileOp whileOp,
222222
RewriterBase &rewriter,
223223
bool forceCreateCheck = false);
224224

225+
/// Try to uplift `scf.while` op to `scf.for`.
226+
/// Uplifting expects a specific ops pattern:
227+
/// * `before` block consisting of a single `arith.cmpi` op
228+
/// * `after` block containing arith.addi
229+
FailureOr<ForOp> upliftWhileToForLoop(RewriterBase &rewriter, WhileOp loop);
230+
225231
} // namespace scf
226232
} // namespace mlir
227233

mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
1414
StructuralTypeConversions.cpp
1515
TileUsingInterface.cpp
1616
WrapInZeroTripCheck.cpp
17+
UpliftWhileToFor.cpp
1718

1819
ADDITIONAL_HEADER_DIRS
1920
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SCF
Lines changed: 222 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,222 @@
1+
//===- UpliftWhileToFor.cpp - scf.while to scf.for loop uplifting ---------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Transforms SCF.WhileOp's into SCF.ForOp's.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/Dialect/SCF/Transforms/Passes.h"
14+
15+
#include "mlir/Dialect/Arith/IR/Arith.h"
16+
#include "mlir/Dialect/SCF/IR/SCF.h"
17+
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
18+
#include "mlir/IR/Dominance.h"
19+
#include "mlir/IR/PatternMatch.h"
20+
21+
using namespace mlir;
22+
23+
namespace {
24+
/// Rewrite pattern anchored on `scf.while`. It holds no logic of its own: every
/// match is forwarded to `scf::upliftWhileToForLoop`, and that function's
/// success/failure becomes the result of the pattern.
struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
  using OpRewritePattern<scf::WhileOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(scf::WhileOp whileOp,
                                PatternRewriter &rewriter) const override {
    // The returned `FailureOr<scf::ForOp>` converts to `LogicalResult`; the
    // created loop itself is not needed here.
    return scf::upliftWhileToForLoop(rewriter, whileOp);
  }
};
32+
} // namespace
33+
34+
/// Try to rewrite `loop` (an `scf.while`) into an equivalent `scf.for`.
///
/// The rewrite only fires when all of the following hold:
///  * the `before` block contains a single `arith.cmpi` (`slt` or `sgt`) whose
///    only use is the condition operand of `scf.condition`;
///  * all `before` block arguments are forwarded unchanged and in order to
///    `scf.condition`;
///  * one side of the compare is a `before` block argument (the induction
///    variable) with no uses other than the compare and `scf.condition`, and
///    the other side is defined outside of the loop (the upper bound);
///  * the `after` block yields, at the induction variable position, an
///    `arith.addi` of the induction variable and a value defined outside of
///    the loop (the step).
///
/// On success the `scf.while` is replaced and the newly created `scf.for` is
/// returned. Every failure path returns before any IR is created, so on
/// failure only a match-failure diagnostic is reported.
///
/// NOTE(review): the step is not checked to be strictly positive here; the
/// rewrite presumably relies on callers only feeding loops with a positive
/// step - confirm against `scf.for` requirements.
FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
                                                      scf::WhileOp loop) {
  Block *beforeBody = loop.getBeforeBody();
  if (!llvm::hasSingleElement(beforeBody->without_terminator()))
    return rewriter.notifyMatchFailure(loop, "Loop body must have single op");

  auto cmp = dyn_cast<arith::CmpIOp>(beforeBody->front());
  if (!cmp)
    return rewriter.notifyMatchFailure(loop,
                                       "Loop body must have single cmp op");

  scf::ConditionOp beforeTerm = loop.getConditionOp();
  if (!cmp->hasOneUse() || beforeTerm.getCondition() != cmp.getResult())
    return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
      diag << "Expected single condition use: " << *cmp;
    });

  // All `before` block args must be directly forwarded to ConditionOp.
  // They will be converted to `scf.for` `iter_vars` except induction var.
  if (ValueRange(beforeBody->getArguments()) != beforeTerm.getArgs())
    return rewriter.notifyMatchFailure(loop, "Invalid args order");

  using Pred = arith::CmpIPredicate;
  Pred predicate = cmp.getPredicate();
  if (predicate != Pred::slt && predicate != Pred::sgt)
    return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
      diag << "Expected 'slt' or 'sgt' predicate: " << *cmp;
    });

  BlockArgument inductionVar;
  Value ub;
  DominanceInfo dom;

  // Check if cmp has a suitable form. One of the arguments must be a `before`
  // block arg, other must be defined outside `scf.while` and will be treated
  // as upper bound. `iv slt ub` and `ub sgt iv` are both accepted; the second
  // form is handled by swapping the operands.
  for (bool reverse : {false, true}) {
    auto expectedPred = reverse ? Pred::sgt : Pred::slt;
    if (cmp.getPredicate() != expectedPred)
      continue;

    auto arg1 = reverse ? cmp.getRhs() : cmp.getLhs();
    auto arg2 = reverse ? cmp.getLhs() : cmp.getRhs();

    auto blockArg = dyn_cast<BlockArgument>(arg1);
    if (!blockArg || blockArg.getOwner() != beforeBody)
      continue;

    if (!dom.properlyDominates(arg2, loop))
      continue;

    inductionVar = blockArg;
    ub = arg2;
    break;
  }

  if (!inductionVar)
    return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
      diag << "Unrecognized cmp form: " << *cmp;
    });

  // inductionVar must have 2 uses: one is in `cmp` and other is `condition`
  // arg.
  if (!llvm::hasNItems(inductionVar.getUses(), 2))
    return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
      diag << "Unrecognized induction var: " << inductionVar;
    });

  Block *afterBody = loop.getAfterBody();
  scf::YieldOp afterTerm = loop.getYieldOp();
  auto argNumber = inductionVar.getArgNumber();
  auto afterTermIndArg = afterTerm.getResults()[argNumber];

  // Because `before` args are forwarded in order, the `after` block argument
  // at the same position carries the induction variable.
  auto inductionVarAfter = afterBody->getArgument(argNumber);

  Value step;

  // Find suitable `addi` op inside `after` block, one of the args must be an
  // Induction var passed from `before` block and second arg must be defined
  // outside of the loop and will be considered step value.
  // TODO: Add `subi` support?
  for (auto &use : inductionVarAfter.getUses()) {
    auto owner = dyn_cast<arith::AddIOp>(use.getOwner());
    if (!owner)
      continue;

    auto other =
        (inductionVarAfter == owner.getLhs() ? owner.getRhs() : owner.getLhs());
    if (!dom.properlyDominates(other, loop))
      continue;

    // The increment must be the value yielded back as the next induction var.
    if (afterTermIndArg != owner.getResult())
      continue;

    step = other;
    break;
  }

  if (!step)
    return rewriter.notifyMatchFailure(loop, "Didn't found suitable 'addi' op");

  auto lb = loop.getInits()[argNumber];

  assert(lb.getType().isIntOrIndex());
  assert(lb.getType() == ub.getType());
  assert(lb.getType() == step.getType());

  llvm::SmallVector<Value> newArgs;

  // Populate inits for new `scf.for`, skip induction var.
  newArgs.reserve(loop.getInits().size());
  for (auto &&[i, init] : llvm::enumerate(loop.getInits())) {
    if (i == argNumber)
      continue;

    newArgs.emplace_back(init);
  }

  Location loc = loop.getLoc();

  // With `builder == nullptr`, ForOp::build will try to insert terminator at
  // the end of newly created block and we don't want it. Provide empty
  // dummy builder instead.
  auto emptyBuilder = [](OpBuilder &, Location, Value, ValueRange) {};
  auto newLoop =
      rewriter.create<scf::ForOp>(loc, lb, ub, step, newArgs, emptyBuilder);

  Block *newBody = newLoop.getBody();

  // Populate block args for `scf.for` body, move induction var to the front.
  // The `scf.for` body has the induction variable first followed by the iter
  // args, while the `after` block has it at position `argNumber`.
  newArgs.clear();
  ValueRange newBodyArgs = newBody->getArguments();
  for (auto i : llvm::seq<size_t>(0, newBodyArgs.size())) {
    if (i < argNumber) {
      newArgs.emplace_back(newBodyArgs[i + 1]);
    } else if (i == argNumber) {
      newArgs.emplace_back(newBodyArgs.front());
    } else {
      newArgs.emplace_back(newBodyArgs[i]);
    }
  }

  rewriter.inlineBlockBefore(loop.getAfterBody(), newBody, newBody->end(),
                             newArgs);

  auto term = cast<scf::YieldOp>(newBody->getTerminator());

  // Populate new yield args, skipping the induction var.
  newArgs.clear();
  for (auto &&[i, arg] : llvm::enumerate(term.getResults())) {
    if (i == argNumber)
      continue;

    newArgs.emplace_back(arg);
  }

  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(term);
  rewriter.replaceOpWithNewOp<scf::YieldOp>(term, newArgs);

  // Compute induction var value after loop execution. `scf.while` returns the
  // values forwarded by `scf.condition` when the condition evaluates to false,
  // so the induction variable result is the first value that fails `iv < ub`:
  //   lb + max(ceildiv(ub - lb, step), 0) * step
  // When no iteration runs (`lb >= ub`) the clamp makes the result `lb`.
  // This assumes `step > 0` and no overflow in the intermediate arithmetic.
  rewriter.setInsertionPointAfter(newLoop);
  Value zero;
  Value one;
  if (isa<IndexType>(step.getType())) {
    zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
  } else {
    zero = rewriter.create<arith::ConstantIntOp>(loc, 0, step.getType());
    one = rewriter.create<arith::ConstantIntOp>(loc, 1, step.getType());
  }

  // tripCount = max((ub - lb + step - 1) / step, 0). The signed division
  // truncates toward zero, which is a ceiling division for a positive
  // numerator and yields a non-positive value when `lb >= ub`.
  Value stepDec = rewriter.create<arith::SubIOp>(loc, step, one);
  Value tripCount = rewriter.create<arith::SubIOp>(loc, ub, lb);
  tripCount = rewriter.create<arith::AddIOp>(loc, tripCount, stepDec);
  tripCount = rewriter.create<arith::DivSIOp>(loc, tripCount, step);
  tripCount = rewriter.create<arith::MaxSIOp>(loc, tripCount, zero);
  Value res = rewriter.create<arith::MulIOp>(loc, tripCount, step);
  res = rewriter.create<arith::AddIOp>(loc, lb, res);

  // Reconstruct `scf.while` results, inserting final induction var value
  // into proper place.
  newArgs.clear();
  llvm::append_range(newArgs, newLoop.getResults());
  newArgs.insert(newArgs.begin() + argNumber, res);
  rewriter.replaceOp(loop, newArgs);
  return newLoop;
}
219+
220+
/// Register the `scf.while` -> `scf.for` uplifting pattern (`UpliftWhileOp`)
/// in `patterns`, using the context the pattern set already belongs to.
void mlir::scf::populateUpliftWhileToForPatterns(RewritePatternSet &patterns) {
  auto *context = patterns.getContext();
  patterns.add<UpliftWhileOp>(context);
}

0 commit comments

Comments
 (0)