@@ -41,26 +41,31 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
41
41
return rewriter.notifyMatchFailure (loop,
42
42
" Loop body must have single cmp op" );
43
43
44
- auto beforeTerm = cast<scf::ConditionOp>(beforeBody->getTerminator ());
45
- if (!llvm::hasSingleElement (cmp->getUses ()) &&
46
- beforeTerm.getCondition () == cmp.getResult ())
44
+ scf::ConditionOp beforeTerm = loop.getConditionOp ();
45
+ if (!cmp->hasOneUse () && beforeTerm.getCondition () == cmp.getResult ())
47
46
return rewriter.notifyMatchFailure (loop, [&](Diagnostic &diag) {
48
47
diag << " Expected single condiditon use: " << *cmp;
49
48
});
50
49
50
+ // All `before` block args must be directly forwarded to ConditionOp.
51
+ // They will be converted to `scf.for` `iter_vars` except induction var.
51
52
if (ValueRange (beforeBody->getArguments ()) != beforeTerm.getArgs ())
52
53
return rewriter.notifyMatchFailure (loop, " Invalid args order" );
53
54
54
55
using Pred = arith::CmpIPredicate;
55
- auto predicate = cmp.getPredicate ();
56
+ Pred predicate = cmp.getPredicate ();
56
57
if (predicate != Pred::slt && predicate != Pred::sgt)
57
58
return rewriter.notifyMatchFailure (loop, [&](Diagnostic &diag) {
58
59
diag << " Expected 'slt' or 'sgt' predicate: " << *cmp;
59
60
});
60
61
61
- BlockArgument iterVar ;
62
+ BlockArgument indVar ;
62
63
Value end;
63
64
DominanceInfo dom;
65
+
66
+ // Check if cmp has a suitable form. One of the arguments must be a `before`
67
+ // block arg, other must be defined outside `scf.while` and will be treated
68
+ // as upper bound.
64
69
for (bool reverse : {false , true }) {
65
70
auto expectedPred = reverse ? Pred::sgt : Pred::slt;
66
71
if (cmp.getPredicate () != expectedPred)
@@ -76,36 +81,42 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
76
81
if (!dom.properlyDominates (arg2, loop))
77
82
continue ;
78
83
79
- iterVar = blockArg;
84
+ indVar = blockArg;
80
85
end = arg2;
81
86
break ;
82
87
}
83
88
84
- if (!iterVar )
89
+ if (!indVar )
85
90
return rewriter.notifyMatchFailure (loop, [&](Diagnostic &diag) {
86
91
diag << " Unrecognized cmp form: " << *cmp;
87
92
});
88
93
89
- if (!llvm::hasNItems (iterVar.getUses (), 2 ))
94
+ // indVar must have 2 uses: one is in `cmp` and other is `condition` arg.
95
+ if (!llvm::hasNItems (indVar.getUses (), 2 ))
90
96
return rewriter.notifyMatchFailure (loop, [&](Diagnostic &diag) {
91
- diag << " Unrecognized iter var: " << iterVar ;
97
+ diag << " Unrecognized induction var: " << indVar ;
92
98
});
93
99
94
100
Block *afterBody = loop.getAfterBody ();
95
- auto afterTerm = cast<scf::YieldOp>(afterBody-> getTerminator () );
96
- auto argNumber = iterVar .getArgNumber ();
101
+ scf::YieldOp afterTerm = loop. getYieldOp ( );
102
+ auto argNumber = indVar .getArgNumber ();
97
103
auto afterTermIterArg = afterTerm.getResults ()[argNumber];
98
104
99
- auto iterVarAfter = afterBody->getArgument (argNumber);
105
+ auto indVarAfter = afterBody->getArgument (argNumber);
100
106
101
107
Value step;
102
- for (auto &use : iterVarAfter.getUses ()) {
108
+
109
+ // Find suitable `addi` op inside `after` block, one of the args must be an
110
+ // Induction var passed from `before` block and second arg must be defined
111
+ // outside of the loop and will be considered step value.
112
+ // TODO: Add `subi` support?
113
+ for (auto &use : indVarAfter.getUses ()) {
103
114
auto owner = dyn_cast<arith::AddIOp>(use.getOwner ());
104
115
if (!owner)
105
116
continue ;
106
117
107
118
auto other =
108
- (iterVarAfter == owner.getLhs () ? owner.getRhs () : owner.getLhs ());
119
+ (indVarAfter == owner.getLhs () ? owner.getRhs () : owner.getLhs ());
109
120
if (!dom.properlyDominates (other, loop))
110
121
continue ;
111
122
@@ -118,7 +129,7 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
118
129
119
130
if (!step)
120
131
return rewriter.notifyMatchFailure (loop,
121
- " Didn't found suitable 'add ' op" );
132
+ " Didn't found suitable 'addi ' op" );
122
133
123
134
auto begin = loop.getInits ()[argNumber];
124
135
@@ -136,16 +147,12 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
136
147
}
137
148
138
149
auto loc = loop.getLoc ();
139
- auto emptyBuidler = [](OpBuilder &, Location, Value, ValueRange) {};
150
+ auto emptyBuilder = [](OpBuilder &, Location, Value, ValueRange) {};
140
151
auto newLoop = rewriter.create <scf::ForOp>(loc, begin, end, step, mapping,
141
- emptyBuidler );
152
+ emptyBuilder );
142
153
143
154
Block *newBody = newLoop.getBody ();
144
155
145
- OpBuilder::InsertionGuard g (rewriter);
146
- rewriter.setInsertionPointToStart (newBody);
147
- Value newIterVar = newBody->getArgument (0 );
148
-
149
156
mapping.clear ();
150
157
auto newArgs = newBody->getArguments ();
151
158
for (auto i : llvm::seq<size_t >(0 , newArgs.size ())) {
@@ -171,6 +178,7 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
171
178
mapping.emplace_back (arg);
172
179
}
173
180
181
+ OpBuilder::InsertionGuard g (rewriter);
174
182
rewriter.setInsertionPoint (term);
175
183
rewriter.replaceOpWithNewOp <scf::YieldOp>(term, mapping);
176
184
0 commit comments