@@ -60,7 +60,7 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
60
60
});
61
61
62
62
BlockArgument indVar;
63
- Value end ;
63
+ Value ub ;
64
64
DominanceInfo dom;
65
65
66
66
// Check if cmp has a suitable form. One of the arguments must be a `before`
@@ -82,7 +82,7 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
82
82
continue ;
83
83
84
84
indVar = blockArg;
85
- end = arg2;
85
+ ub = arg2;
86
86
break ;
87
87
}
88
88
@@ -131,57 +131,66 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
131
131
return rewriter.notifyMatchFailure (loop,
132
132
" Didn't found suitable 'addi' op" );
133
133
134
- auto begin = loop.getInits ()[argNumber];
134
+ auto lb = loop.getInits ()[argNumber];
135
135
136
- assert (begin .getType ().isIntOrIndex ());
137
- assert (begin .getType () == end .getType ());
138
- assert (begin .getType () == step.getType ());
136
+ assert (lb .getType ().isIntOrIndex ());
137
+ assert (lb .getType () == ub .getType ());
138
+ assert (lb .getType () == step.getType ());
139
139
140
- llvm::SmallVector<Value> mapping;
141
- mapping.reserve (loop.getInits ().size ());
140
+ llvm::SmallVector<Value> newArgs;
141
+
142
+ // Populate inits for new `scf.for`
143
+ newArgs.reserve (loop.getInits ().size ());
142
144
for (auto &&[i, init] : llvm::enumerate (loop.getInits ())) {
143
145
if (i == argNumber)
144
146
continue ;
145
147
146
- mapping .emplace_back (init);
148
+ newArgs .emplace_back (init);
147
149
}
148
150
149
151
auto loc = loop.getLoc ();
152
+
153
+ // With `builder == nullptr`, ForOp::build will try to insert a terminator at
154
+ // the end of the newly created block, which we don't want. Provide an empty
155
+ // dummy builder instead.
150
156
auto emptyBuilder = [](OpBuilder &, Location, Value, ValueRange) {};
151
- auto newLoop = rewriter. create <scf::ForOp>(loc, begin, end, step, mapping,
152
- emptyBuilder);
157
+ auto newLoop =
158
+ rewriter. create <scf::ForOp>(loc, lb, ub, step, newArgs, emptyBuilder);
153
159
154
160
Block *newBody = newLoop.getBody ();
155
161
156
- mapping.clear ();
157
- auto newArgs = newBody->getArguments ();
158
- for (auto i : llvm::seq<size_t >(0 , newArgs.size ())) {
162
+ // Populate block args for `scf.for` body, move induction var to the front.
163
+ newArgs.clear ();
164
+ ValueRange newBodyArgs = newBody->getArguments ();
165
+ for (auto i : llvm::seq<size_t >(0 , newBodyArgs.size ())) {
159
166
if (i < argNumber) {
160
- mapping .emplace_back (newArgs [i + 1 ]);
167
+ newArgs .emplace_back (newBodyArgs [i + 1 ]);
161
168
} else if (i == argNumber) {
162
- mapping .emplace_back (newArgs .front ());
169
+ newArgs .emplace_back (newBodyArgs .front ());
163
170
} else {
164
- mapping .emplace_back (newArgs [i]);
171
+ newArgs .emplace_back (newBodyArgs [i]);
165
172
}
166
173
}
167
174
168
175
rewriter.inlineBlockBefore (loop.getAfterBody (), newBody, newBody->end (),
169
- mapping );
176
+ newArgs );
170
177
171
178
auto term = cast<scf::YieldOp>(newBody->getTerminator ());
172
179
173
- mapping.clear ();
180
+ // Populate new yield args, skipping the induction var.
181
+ newArgs.clear ();
174
182
for (auto &&[i, arg] : llvm::enumerate (term.getResults ())) {
175
183
if (i == argNumber)
176
184
continue ;
177
185
178
- mapping .emplace_back (arg);
186
+ newArgs .emplace_back (arg);
179
187
}
180
188
181
189
OpBuilder::InsertionGuard g (rewriter);
182
190
rewriter.setInsertionPoint (term);
183
- rewriter.replaceOpWithNewOp <scf::YieldOp>(term, mapping );
191
+ rewriter.replaceOpWithNewOp <scf::YieldOp>(term, newArgs );
184
192
193
+ // Compute induction var value after loop execution.
185
194
rewriter.setInsertionPointAfter (newLoop);
186
195
Value one;
187
196
if (isa<IndexType>(step.getType ())) {
@@ -191,17 +200,19 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
191
200
}
192
201
193
202
Value stepDec = rewriter.create <arith::SubIOp>(loc, step, one);
194
- Value len = rewriter.create <arith::SubIOp>(loc, end, begin );
203
+ Value len = rewriter.create <arith::SubIOp>(loc, ub, lb );
195
204
len = rewriter.create <arith::AddIOp>(loc, len, stepDec);
196
205
len = rewriter.create <arith::DivSIOp>(loc, len, step);
197
206
len = rewriter.create <arith::SubIOp>(loc, len, one);
198
207
Value res = rewriter.create <arith::MulIOp>(loc, len, step);
199
- res = rewriter.create <arith::AddIOp>(loc, begin, res);
200
-
201
- mapping.clear ();
202
- llvm::append_range (mapping, newLoop.getResults ());
203
- mapping.insert (mapping.begin () + argNumber, res);
204
- rewriter.replaceOp (loop, mapping);
208
+ res = rewriter.create <arith::AddIOp>(loc, lb, res);
209
+
210
+ // Reconstruct `scf.while` results, inserting the final induction var value
211
+ // into its proper place.
212
+ newArgs.clear ();
213
+ llvm::append_range (newArgs, newLoop.getResults ());
214
+ newArgs.insert (newArgs.begin () + argNumber, res);
215
+ rewriter.replaceOp (loop, newArgs);
205
216
return success ();
206
217
}
207
218
};
0 commit comments