@@ -48,10 +48,46 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
48
48
diag << " Expected single condition use: " << *cmp;
49
49
});
50
50
51
+ // If all 'before' arguments are forwarded but the order is different from
52
+ // 'after' arguments, here is the mapping from the 'after' argument index to
53
+ // the 'before' argument index.
54
+ std::optional<SmallVector<unsigned >> argReorder;
51
55
// All `before` block args must be directly forwarded to ConditionOp.
52
56
// They will be converted to `scf.for` `iter_vars` except induction var.
53
- if (ValueRange (beforeBody->getArguments ()) != beforeTerm.getArgs ())
54
- return rewriter.notifyMatchFailure (loop, " Invalid args order" );
57
+ if (ValueRange (beforeBody->getArguments ()) != beforeTerm.getArgs ()) {
58
+ auto getArgReordering =
59
+ [](Block *beforeBody,
60
+ scf::ConditionOp cond) -> std::optional<SmallVector<unsigned >> {
61
+ // Skip further checking if their sizes mismatch.
62
+ if (beforeBody->getNumArguments () != cond.getArgs ().size ())
63
+ return std::nullopt;
64
+ // Bitset on which 'before' argument is forwarded.
65
+ llvm::SmallBitVector forwarded (beforeBody->getNumArguments (), false );
66
+ // The forwarding order of 'before' arguments.
67
+ SmallVector<unsigned > order;
68
+ for (Value a : cond.getArgs ()) {
69
+ BlockArgument arg = dyn_cast<BlockArgument>(a);
70
+ // Skip if 'arg' is not a 'before' argument.
71
+ if (!arg || arg.getOwner () != beforeBody)
72
+ return std::nullopt;
73
+ unsigned idx = arg.getArgNumber ();
74
+ // Skip if 'arg' is already forwarded in another place.
75
+ if (forwarded[idx])
76
+ return std::nullopt;
77
+ // Record the presence of 'arg' and its order.
78
+ forwarded[idx] = true ;
79
+ order.push_back (idx);
80
+ }
81
+ // Skip if not all 'before' arguments are forwarded.
82
+ if (!forwarded.all ())
83
+ return std::nullopt;
84
+ return order;
85
+ };
86
+ // Check if 'before' arguments are all forwarded but just reordered.
87
+ argReorder = getArgReordering (beforeBody, beforeTerm);
88
+ if (!argReorder)
89
+ return rewriter.notifyMatchFailure (loop, " Invalid args order" );
90
+ }
55
91
56
92
using Pred = arith::CmpIPredicate;
57
93
Pred predicate = cmp.getPredicate ();
@@ -104,7 +140,14 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
104
140
unsigned argNumber = inductionVar.getArgNumber ();
105
141
Value afterTermIndArg = afterTerm.getResults ()[argNumber];
106
142
107
- Value inductionVarAfter = afterBody->getArgument (argNumber);
143
+ auto findAfterArgNo = [](ArrayRef<unsigned > indices, unsigned beforeArgNo) {
144
+ return std::distance (indices.begin (),
145
+ llvm::find_if (indices, [beforeArgNo](unsigned n) {
146
+ return n == beforeArgNo;
147
+ }));
148
+ };
149
+ Value inductionVarAfter = afterBody->getArgument (
150
+ argReorder ? findAfterArgNo (*argReorder, argNumber) : argNumber);
108
151
109
152
// Find suitable `addi` op inside `after` block, one of the args must be an
110
153
// Induction var passed from `before` block and second arg must be defined
@@ -130,7 +173,7 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
130
173
assert (lb.getType () == ub.getType ());
131
174
assert (lb.getType () == step.getType ());
132
175
133
- llvm:: SmallVector<Value> newArgs;
176
+ SmallVector<Value> newArgs;
134
177
135
178
// Populate inits for new `scf.for`, skip induction var.
136
179
newArgs.reserve (loop.getInits ().size ());
@@ -164,6 +207,14 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
164
207
newArgs.emplace_back (newBodyArgs[i]);
165
208
}
166
209
}
210
+ if (argReorder) {
211
+ // Reorder arguments following the 'after' argument order from the original
212
+ // 'while' loop.
213
+ SmallVector<Value> args;
214
+ for (unsigned order : *argReorder)
215
+ args.push_back (newArgs[order]);
216
+ newArgs = args;
217
+ }
167
218
168
219
rewriter.inlineBlockBefore (loop.getAfterBody (), newBody, newBody->end (),
169
220
newArgs);
@@ -205,6 +256,14 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
205
256
newArgs.clear ();
206
257
llvm::append_range (newArgs, newLoop.getResults ());
207
258
newArgs.insert (newArgs.begin () + argNumber, res);
259
+ if (argReorder) {
260
+ // Reorder arguments following the 'after' argument order from the original
261
+ // 'while' loop.
262
+ SmallVector<Value> results;
263
+ for (unsigned order : *argReorder)
264
+ results.push_back (newArgs[order]);
265
+ newArgs = results;
266
+ }
208
267
rewriter.replaceOp (loop, newArgs);
209
268
return newLoop;
210
269
}
0 commit comments