Skip to content

Commit fee42a6

Browse files
committed
cleanup
1 parent 08da5b8 commit fee42a6

File tree

2 files changed

+33
-21
lines changed

2 files changed

+33
-21
lines changed

mlir/include/mlir/Dialect/SCF/Transforms/Passes.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,10 @@ def SCFUpliftWhileToFor : Pass<"scf-uplift-while-to-for"> {
160160
This pass tries to uplift `scf.while` ops to `scf.for` if they have a
161161
compatible form. `scf.while` are left unchanged if uplifting is not
162162
possible.
163+
164+
This pass expects a specific ops pattern:
165+
* `before` block consisting of single arith.cmp op
166+
* `after` block containing arith.addi
163167
}];
164168
}
165169

mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp

Lines changed: 29 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -41,26 +41,31 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
4141
return rewriter.notifyMatchFailure(loop,
4242
"Loop body must have single cmp op");
4343

44-
auto beforeTerm = cast<scf::ConditionOp>(beforeBody->getTerminator());
45-
if (!llvm::hasSingleElement(cmp->getUses()) &&
46-
beforeTerm.getCondition() == cmp.getResult())
44+
scf::ConditionOp beforeTerm = loop.getConditionOp();
45+
if (!cmp->hasOneUse() && beforeTerm.getCondition() == cmp.getResult())
4746
return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
4847
diag << "Expected single condiditon use: " << *cmp;
4948
});
5049

50+
// All `before` block args must be directly forwarded to ConditionOp.
51+
// They will be converted to `scf.for` `iter_vars` except induction var.
5152
if (ValueRange(beforeBody->getArguments()) != beforeTerm.getArgs())
5253
return rewriter.notifyMatchFailure(loop, "Invalid args order");
5354

5455
using Pred = arith::CmpIPredicate;
55-
auto predicate = cmp.getPredicate();
56+
Pred predicate = cmp.getPredicate();
5657
if (predicate != Pred::slt && predicate != Pred::sgt)
5758
return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
5859
diag << "Expected 'slt' or 'sgt' predicate: " << *cmp;
5960
});
6061

61-
BlockArgument iterVar;
62+
BlockArgument indVar;
6263
Value end;
6364
DominanceInfo dom;
65+
66+
// Check if cmp has a suitable form. One of the arguments must be a `before`
67+
// block arg, other must be defined outside `scf.while` and will be treated
68+
// as upper bound.
6469
for (bool reverse : {false, true}) {
6570
auto expectedPred = reverse ? Pred::sgt : Pred::slt;
6671
if (cmp.getPredicate() != expectedPred)
@@ -76,36 +81,42 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
7681
if (!dom.properlyDominates(arg2, loop))
7782
continue;
7883

79-
iterVar = blockArg;
84+
indVar = blockArg;
8085
end = arg2;
8186
break;
8287
}
8388

84-
if (!iterVar)
89+
if (!indVar)
8590
return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
8691
diag << "Unrecognized cmp form: " << *cmp;
8792
});
8893

89-
if (!llvm::hasNItems(iterVar.getUses(), 2))
94+
// indVar must have 2 uses: one is in `cmp` and other is `condition` arg.
95+
if (!llvm::hasNItems(indVar.getUses(), 2))
9096
return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
91-
diag << "Unrecognized iter var: " << iterVar;
97+
diag << "Unrecognized induction var: " << indVar;
9298
});
9399

94100
Block *afterBody = loop.getAfterBody();
95-
auto afterTerm = cast<scf::YieldOp>(afterBody->getTerminator());
96-
auto argNumber = iterVar.getArgNumber();
101+
scf::YieldOp afterTerm = loop.getYieldOp();
102+
auto argNumber = indVar.getArgNumber();
97103
auto afterTermIterArg = afterTerm.getResults()[argNumber];
98104

99-
auto iterVarAfter = afterBody->getArgument(argNumber);
105+
auto indVarAfter = afterBody->getArgument(argNumber);
100106

101107
Value step;
102-
for (auto &use : iterVarAfter.getUses()) {
108+
109+
// Find suitable `addi` op inside `after` block, one of the args must be an
110+
// Induction var passed from `before` block and second arg must be defined
111+
// outside of the loop and will be considered step value.
112+
// TODO: Add `subi` support?
113+
for (auto &use : indVarAfter.getUses()) {
103114
auto owner = dyn_cast<arith::AddIOp>(use.getOwner());
104115
if (!owner)
105116
continue;
106117

107118
auto other =
108-
(iterVarAfter == owner.getLhs() ? owner.getRhs() : owner.getLhs());
119+
(indVarAfter == owner.getLhs() ? owner.getRhs() : owner.getLhs());
109120
if (!dom.properlyDominates(other, loop))
110121
continue;
111122

@@ -118,7 +129,7 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
118129

119130
if (!step)
120131
return rewriter.notifyMatchFailure(loop,
121-
"Didn't found suitable 'add' op");
132+
"Didn't found suitable 'addi' op");
122133

123134
auto begin = loop.getInits()[argNumber];
124135

@@ -136,16 +147,12 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
136147
}
137148

138149
auto loc = loop.getLoc();
139-
auto emptyBuidler = [](OpBuilder &, Location, Value, ValueRange) {};
150+
auto emptyBuilder = [](OpBuilder &, Location, Value, ValueRange) {};
140151
auto newLoop = rewriter.create<scf::ForOp>(loc, begin, end, step, mapping,
141-
emptyBuidler);
152+
emptyBuilder);
142153

143154
Block *newBody = newLoop.getBody();
144155

145-
OpBuilder::InsertionGuard g(rewriter);
146-
rewriter.setInsertionPointToStart(newBody);
147-
Value newIterVar = newBody->getArgument(0);
148-
149156
mapping.clear();
150157
auto newArgs = newBody->getArguments();
151158
for (auto i : llvm::seq<size_t>(0, newArgs.size())) {
@@ -171,6 +178,7 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
171178
mapping.emplace_back(arg);
172179
}
173180

181+
OpBuilder::InsertionGuard g(rewriter);
174182
rewriter.setInsertionPoint(term);
175183
rewriter.replaceOpWithNewOp<scf::YieldOp>(term, mapping);
176184

0 commit comments

Comments
 (0)