Skip to content

Commit d937b98

Browse files
committed
Renamings and comments
1 parent fee42a6 commit d937b98

File tree

1 file changed

+39
-28
lines changed

1 file changed

+39
-28
lines changed

mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp

Lines changed: 39 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
6060
});
6161

6262
BlockArgument indVar;
63-
Value end;
63+
Value ub;
6464
DominanceInfo dom;
6565

6666
// Check if cmp has a suitable form. One of the arguments must be a `before`
@@ -82,7 +82,7 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
8282
continue;
8383

8484
indVar = blockArg;
85-
end = arg2;
85+
ub = arg2;
8686
break;
8787
}
8888

@@ -131,57 +131,66 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
131131
return rewriter.notifyMatchFailure(loop,
132132
"Didn't found suitable 'addi' op");
133133

134-
auto begin = loop.getInits()[argNumber];
134+
auto lb = loop.getInits()[argNumber];
135135

136-
assert(begin.getType().isIntOrIndex());
137-
assert(begin.getType() == end.getType());
138-
assert(begin.getType() == step.getType());
136+
assert(lb.getType().isIntOrIndex());
137+
assert(lb.getType() == ub.getType());
138+
assert(lb.getType() == step.getType());
139139

140-
llvm::SmallVector<Value> mapping;
141-
mapping.reserve(loop.getInits().size());
140+
llvm::SmallVector<Value> newArgs;
141+
142+
// Populate inits for new `scf.for`
143+
newArgs.reserve(loop.getInits().size());
142144
for (auto &&[i, init] : llvm::enumerate(loop.getInits())) {
143145
if (i == argNumber)
144146
continue;
145147

146-
mapping.emplace_back(init);
148+
newArgs.emplace_back(init);
147149
}
148150

149151
auto loc = loop.getLoc();
152+
153+
// With `builder == nullptr`, ForOp::build will try to insert terminator at
154+
// the end of newly created block and we don't want it. Provide empty
155+
// dummy builder instead.
150156
auto emptyBuilder = [](OpBuilder &, Location, Value, ValueRange) {};
151-
auto newLoop = rewriter.create<scf::ForOp>(loc, begin, end, step, mapping,
152-
emptyBuilder);
157+
auto newLoop =
158+
rewriter.create<scf::ForOp>(loc, lb, ub, step, newArgs, emptyBuilder);
153159

154160
Block *newBody = newLoop.getBody();
155161

156-
mapping.clear();
157-
auto newArgs = newBody->getArguments();
158-
for (auto i : llvm::seq<size_t>(0, newArgs.size())) {
162+
// Populate block args for `scf.for` body, move induction var to the front.
163+
newArgs.clear();
164+
ValueRange newBodyArgs = newBody->getArguments();
165+
for (auto i : llvm::seq<size_t>(0, newBodyArgs.size())) {
159166
if (i < argNumber) {
160-
mapping.emplace_back(newArgs[i + 1]);
167+
newArgs.emplace_back(newBodyArgs[i + 1]);
161168
} else if (i == argNumber) {
162-
mapping.emplace_back(newArgs.front());
169+
newArgs.emplace_back(newBodyArgs.front());
163170
} else {
164-
mapping.emplace_back(newArgs[i]);
171+
newArgs.emplace_back(newBodyArgs[i]);
165172
}
166173
}
167174

168175
rewriter.inlineBlockBefore(loop.getAfterBody(), newBody, newBody->end(),
169-
mapping);
176+
newArgs);
170177

171178
auto term = cast<scf::YieldOp>(newBody->getTerminator());
172179

173-
mapping.clear();
180+
// Populate new yield args, skipping the induction var.
181+
newArgs.clear();
174182
for (auto &&[i, arg] : llvm::enumerate(term.getResults())) {
175183
if (i == argNumber)
176184
continue;
177185

178-
mapping.emplace_back(arg);
186+
newArgs.emplace_back(arg);
179187
}
180188

181189
OpBuilder::InsertionGuard g(rewriter);
182190
rewriter.setInsertionPoint(term);
183-
rewriter.replaceOpWithNewOp<scf::YieldOp>(term, mapping);
191+
rewriter.replaceOpWithNewOp<scf::YieldOp>(term, newArgs);
184192

193+
// Compute induction var value after loop execution.
185194
rewriter.setInsertionPointAfter(newLoop);
186195
Value one;
187196
if (isa<IndexType>(step.getType())) {
@@ -191,17 +200,19 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
191200
}
192201

193202
Value stepDec = rewriter.create<arith::SubIOp>(loc, step, one);
194-
Value len = rewriter.create<arith::SubIOp>(loc, end, begin);
203+
Value len = rewriter.create<arith::SubIOp>(loc, ub, lb);
195204
len = rewriter.create<arith::AddIOp>(loc, len, stepDec);
196205
len = rewriter.create<arith::DivSIOp>(loc, len, step);
197206
len = rewriter.create<arith::SubIOp>(loc, len, one);
198207
Value res = rewriter.create<arith::MulIOp>(loc, len, step);
199-
res = rewriter.create<arith::AddIOp>(loc, begin, res);
200-
201-
mapping.clear();
202-
llvm::append_range(mapping, newLoop.getResults());
203-
mapping.insert(mapping.begin() + argNumber, res);
204-
rewriter.replaceOp(loop, mapping);
208+
res = rewriter.create<arith::AddIOp>(loc, lb, res);
209+
210+
// Reconstruct `scf.while` results, inserting final induction var value
211+
// into proper place.
212+
newArgs.clear();
213+
llvm::append_range(newArgs, newLoop.getResults());
214+
newArgs.insert(newArgs.begin() + argNumber, res);
215+
rewriter.replaceOp(loop, newArgs);
205216
return success();
206217
}
207218
};

0 commit comments

Comments
 (0)