Skip to content

Commit 3883cc8

Browse files
committed
[mlir][scf] Uplift scf.while to scf.for
Add uplifting from `scf.while` to `scf.for`. This uplifting expects a very specifi ops pattern: * `before` body consisting of single `arith.cmp` op * `after` body containing `arith.addi` * Iter var must be of type `index` or integer of specified width We also have a set of patterns to clenaup `scf.while` loops to get them close to the desired form, they will be added in separate PRs. This is part of upstreaming `numba-mlir` scf uplifting pipeline: `cf -> scf.while -> scf.for -> scf.parallel` Original code: https://github.com/numba/numba-mlir/blob/main/mlir/lib/Transforms/PromoteToParallel.cpp
1 parent 5b66b6a commit 3883cc8

File tree

5 files changed

+433
-0
lines changed

5 files changed

+433
-0
lines changed

mlir/include/mlir/Dialect/SCF/Transforms/Passes.td

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,4 +154,20 @@ def SCFForToWhileLoop : Pass<"scf-for-to-while"> {
154154
}];
155155
}
156156

157+
def SCFUpliftWhileToFor : Pass<"scf-uplift-while-to-for"> {
158+
let summary = "Uplift scf.while ops to scf.for";
159+
let description = [{
160+
This pass tries to uplift `scf.while` ops to `scf.for` if they have a
161+
compatible form. `scf.while` are left unchanged if uplifting is not
162+
possible.
163+
}];
164+
165+
let options = [
166+
Option<"indexBitWidth", "index-bitwidth", "unsigned",
167+
/*default=*/"64",
168+
"Bitwidth of index type.">,
169+
];
170+
}
171+
172+
157173
#endif // MLIR_DIALECT_SCF_PASSES

mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,10 @@ void populateSCFLoopPipeliningPatterns(RewritePatternSet &patterns,
7979
/// loop bounds and loop steps are canonicalized.
8080
void populateSCFForLoopCanonicalizationPatterns(RewritePatternSet &patterns);
8181

82+
/// Populate patterns to uplift `scf.while` ops to `scf.for`.
83+
void populateUpliftWhileToForPatterns(RewritePatternSet &patterns,
84+
unsigned indexBitwidth);
85+
8286
} // namespace scf
8387
} // namespace mlir
8488

mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
1414
StructuralTypeConversions.cpp
1515
TileUsingInterface.cpp
1616
WrapInZeroTripCheck.cpp
17+
UpliftWhileToFor.cpp
1718

1819
ADDITIONAL_HEADER_DIRS
1920
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SCF
Lines changed: 250 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,250 @@
1+
//===- UpliftWhileToFor.cpp - scf.while to scf.for loop uplifting ---------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Transforms SCF.WhileOp's into SCF.ForOp's.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/Dialect/SCF/Transforms/Passes.h"
14+
15+
#include "mlir/Dialect/Arith/IR/Arith.h"
16+
#include "mlir/Dialect/SCF/IR/SCF.h"
17+
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
18+
#include "mlir/IR/Dominance.h"
19+
#include "mlir/IR/PatternMatch.h"
20+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21+
22+
namespace mlir {
23+
#define GEN_PASS_DEF_SCFUPLIFTWHILETOFOR
24+
#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
25+
} // namespace mlir
26+
27+
using namespace mlir;
28+
29+
static bool checkIndexType(arith::CmpIOp op, unsigned indexBitWidth) {
30+
auto type = op.getLhs().getType();
31+
if (isa<mlir::IndexType>(type))
32+
return true;
33+
34+
if (type.isSignlessInteger(indexBitWidth))
35+
return true;
36+
37+
return false;
38+
}
39+
40+
namespace {
41+
struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
42+
UpliftWhileOp(MLIRContext *context, unsigned indexBitWidth_)
43+
: OpRewritePattern<scf::WhileOp>(context), indexBitWidth(indexBitWidth_) {
44+
}
45+
46+
LogicalResult matchAndRewrite(scf::WhileOp loop,
47+
PatternRewriter &rewriter) const override {
48+
Block *beforeBody = loop.getBeforeBody();
49+
if (!llvm::hasSingleElement(beforeBody->without_terminator()))
50+
return rewriter.notifyMatchFailure(loop, "Loop body must have single op");
51+
52+
auto cmp = dyn_cast<arith::CmpIOp>(beforeBody->front());
53+
if (!cmp)
54+
return rewriter.notifyMatchFailure(loop,
55+
"Loop body must have single cmp op");
56+
57+
auto beforeTerm = cast<scf::ConditionOp>(beforeBody->getTerminator());
58+
if (!llvm::hasSingleElement(cmp->getUses()) &&
59+
beforeTerm.getCondition() == cmp.getResult())
60+
return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
61+
diag << "Expected single condiditon use: " << *cmp;
62+
});
63+
64+
if (ValueRange(beforeBody->getArguments()) != beforeTerm.getArgs())
65+
return rewriter.notifyMatchFailure(loop, "Invalid args order");
66+
67+
using Pred = arith::CmpIPredicate;
68+
auto predicate = cmp.getPredicate();
69+
if (predicate != Pred::slt && predicate != Pred::sgt)
70+
return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
71+
diag << "Expected 'slt' or 'sgt' predicate: " << *cmp;
72+
});
73+
74+
if (!checkIndexType(cmp, indexBitWidth))
75+
return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
76+
diag << "Expected index-like type: " << *cmp;
77+
});
78+
79+
BlockArgument iterVar;
80+
Value end;
81+
DominanceInfo dom;
82+
for (bool reverse : {false, true}) {
83+
auto expectedPred = reverse ? Pred::sgt : Pred::slt;
84+
if (cmp.getPredicate() != expectedPred)
85+
continue;
86+
87+
auto arg1 = reverse ? cmp.getRhs() : cmp.getLhs();
88+
auto arg2 = reverse ? cmp.getLhs() : cmp.getRhs();
89+
90+
auto blockArg = dyn_cast<BlockArgument>(arg1);
91+
if (!blockArg || blockArg.getOwner() != beforeBody)
92+
continue;
93+
94+
if (!dom.properlyDominates(arg2, loop))
95+
continue;
96+
97+
iterVar = blockArg;
98+
end = arg2;
99+
break;
100+
}
101+
102+
if (!iterVar)
103+
return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
104+
diag << "Unrecognized cmp form: " << *cmp;
105+
});
106+
107+
if (!llvm::hasNItems(iterVar.getUses(), 2))
108+
return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
109+
diag << "Unrecognized iter var: " << iterVar;
110+
});
111+
112+
Block *afterBody = loop.getAfterBody();
113+
auto afterTerm = cast<scf::YieldOp>(afterBody->getTerminator());
114+
auto argNumber = iterVar.getArgNumber();
115+
auto afterTermIterArg = afterTerm.getResults()[argNumber];
116+
117+
auto iterVarAfter = afterBody->getArgument(argNumber);
118+
119+
Value step;
120+
for (auto &use : iterVarAfter.getUses()) {
121+
auto owner = dyn_cast<arith::AddIOp>(use.getOwner());
122+
if (!owner)
123+
continue;
124+
125+
auto other =
126+
(iterVarAfter == owner.getLhs() ? owner.getRhs() : owner.getLhs());
127+
if (!dom.properlyDominates(other, loop))
128+
continue;
129+
130+
if (afterTermIterArg != owner.getResult())
131+
continue;
132+
133+
step = other;
134+
break;
135+
}
136+
137+
if (!step)
138+
return rewriter.notifyMatchFailure(loop,
139+
"Didn't found suitable 'add' op");
140+
141+
auto begin = loop.getInits()[argNumber];
142+
143+
auto loc = loop.getLoc();
144+
auto indexType = rewriter.getIndexType();
145+
auto toIndex = [&](Value val) -> Value {
146+
if (val.getType() != indexType)
147+
return rewriter.create<arith::IndexCastOp>(loc, indexType, val);
148+
149+
return val;
150+
};
151+
begin = toIndex(begin);
152+
end = toIndex(end);
153+
step = toIndex(step);
154+
155+
llvm::SmallVector<Value> mapping;
156+
mapping.reserve(loop.getInits().size());
157+
for (auto &&[i, init] : llvm::enumerate(loop.getInits())) {
158+
if (i == argNumber)
159+
continue;
160+
161+
mapping.emplace_back(init);
162+
}
163+
164+
auto emptyBuidler = [](OpBuilder &, Location, Value, ValueRange) {};
165+
auto newLoop = rewriter.create<scf::ForOp>(loc, begin, end, step, mapping,
166+
emptyBuidler);
167+
168+
Block *newBody = newLoop.getBody();
169+
170+
OpBuilder::InsertionGuard g(rewriter);
171+
rewriter.setInsertionPointToStart(newBody);
172+
Value newIterVar = newBody->getArgument(0);
173+
if (newIterVar.getType() != iterVar.getType())
174+
newIterVar = rewriter.create<arith::IndexCastOp>(loc, iterVar.getType(),
175+
newIterVar);
176+
177+
mapping.clear();
178+
auto newArgs = newBody->getArguments();
179+
for (auto i : llvm::seq<size_t>(0, newArgs.size())) {
180+
if (i < argNumber) {
181+
mapping.emplace_back(newArgs[i + 1]);
182+
} else if (i == argNumber) {
183+
Value arg = newArgs.front();
184+
if (arg.getType() != iterVar.getType())
185+
arg =
186+
rewriter.create<arith::IndexCastOp>(loc, iterVar.getType(), arg);
187+
mapping.emplace_back(arg);
188+
} else {
189+
mapping.emplace_back(newArgs[i]);
190+
}
191+
}
192+
193+
rewriter.inlineBlockBefore(loop.getAfterBody(), newBody, newBody->end(),
194+
mapping);
195+
196+
auto term = cast<scf::YieldOp>(newBody->getTerminator());
197+
198+
mapping.clear();
199+
for (auto &&[i, arg] : llvm::enumerate(term.getResults())) {
200+
if (i == argNumber)
201+
continue;
202+
203+
mapping.emplace_back(arg);
204+
}
205+
206+
rewriter.setInsertionPoint(term);
207+
rewriter.replaceOpWithNewOp<scf::YieldOp>(term, mapping);
208+
209+
rewriter.setInsertionPointAfter(newLoop);
210+
Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
211+
Value stepDec = rewriter.create<arith::SubIOp>(loc, step, one);
212+
Value len = rewriter.create<arith::SubIOp>(loc, end, begin);
213+
len = rewriter.create<arith::AddIOp>(loc, len, stepDec);
214+
len = rewriter.create<arith::DivSIOp>(loc, len, step);
215+
len = rewriter.create<arith::SubIOp>(loc, len, one);
216+
Value res = rewriter.create<arith::MulIOp>(loc, len, step);
217+
res = rewriter.create<arith::AddIOp>(loc, begin, res);
218+
if (res.getType() != iterVar.getType())
219+
res = rewriter.create<arith::IndexCastOp>(loc, iterVar.getType(), res);
220+
221+
mapping.clear();
222+
llvm::append_range(mapping, newLoop.getResults());
223+
mapping.insert(mapping.begin() + argNumber, res);
224+
rewriter.replaceOp(loop, mapping);
225+
return success();
226+
}
227+
228+
private:
229+
unsigned indexBitWidth = 0;
230+
};
231+
232+
struct SCFUpliftWhileToFor final
233+
: impl::SCFUpliftWhileToForBase<SCFUpliftWhileToFor> {
234+
using SCFUpliftWhileToForBase::SCFUpliftWhileToForBase;
235+
236+
void runOnOperation() override {
237+
Operation *op = getOperation();
238+
MLIRContext *ctx = op->getContext();
239+
RewritePatternSet patterns(ctx);
240+
mlir::scf::populateUpliftWhileToForPatterns(patterns, this->indexBitWidth);
241+
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
242+
signalPassFailure();
243+
}
244+
};
245+
} // namespace
246+
247+
void mlir::scf::populateUpliftWhileToForPatterns(RewritePatternSet &patterns,
248+
unsigned indexBitwidth) {
249+
patterns.add<UpliftWhileOp>(patterns.getContext(), indexBitwidth);
250+
}

0 commit comments

Comments
 (0)