Skip to content

Commit 52975d5

Browse files
committed
[mlir][scf] Allow different forwarding ordering in uplift
- Allow 'before' arguments are forwarded in different order to 'after' body when uplifting `scf.while` to `scf.for`.
1 parent 10f983a commit 52975d5

File tree

2 files changed

+93
-4
lines changed

2 files changed

+93
-4
lines changed

mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp

Lines changed: 63 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,46 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
4848
diag << "Expected single condition use: " << *cmp;
4949
});
5050

51+
// If all 'before' arguments are forwarded but the order is different from
52+
// 'after' arguments, here is the mapping from the 'after' argument index to
53+
// the 'before' argument index.
54+
std::optional<SmallVector<unsigned>> argReorder;
5155
// All `before` block args must be directly forwarded to ConditionOp.
5256
// They will be converted to `scf.for` `iter_vars` except induction var.
53-
if (ValueRange(beforeBody->getArguments()) != beforeTerm.getArgs())
54-
return rewriter.notifyMatchFailure(loop, "Invalid args order");
57+
if (ValueRange(beforeBody->getArguments()) != beforeTerm.getArgs()) {
58+
auto getArgReordering =
59+
[](Block *beforeBody,
60+
scf::ConditionOp cond) -> std::optional<SmallVector<unsigned>> {
61+
// Skip further checking if their sizes mismatch.
62+
if (beforeBody->getNumArguments() != cond.getArgs().size())
63+
return std::nullopt;
64+
// Bitset on which 'before' argument is forwarded.
65+
llvm::SmallBitVector forwarded(beforeBody->getNumArguments(), false);
66+
// The forwarding order of 'before' arguments.
67+
SmallVector<unsigned> order;
68+
for (Value a : cond.getArgs()) {
69+
BlockArgument arg = dyn_cast<BlockArgument>(a);
70+
// Skip if 'arg' is not a 'before' argument.
71+
if (!arg || arg.getOwner() != beforeBody)
72+
return std::nullopt;
73+
unsigned idx = arg.getArgNumber();
74+
// Skip if 'arg' is already forwarded in another place.
75+
if (forwarded[idx])
76+
return std::nullopt;
77+
// Record the presence of 'arg' and its order.
78+
forwarded[idx] = true;
79+
order.push_back(idx);
80+
}
81+
// Skip if not all 'before' arguments are forwarded.
82+
if (!forwarded.all())
83+
return std::nullopt;
84+
return order;
85+
};
86+
// Check if 'before' arguments are all forwarded but just reordered.
87+
argReorder = getArgReordering(beforeBody, beforeTerm);
88+
if (!argReorder)
89+
return rewriter.notifyMatchFailure(loop, "Invalid args order");
90+
}
5591

5692
using Pred = arith::CmpIPredicate;
5793
Pred predicate = cmp.getPredicate();
@@ -104,7 +140,14 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
104140
unsigned argNumber = inductionVar.getArgNumber();
105141
Value afterTermIndArg = afterTerm.getResults()[argNumber];
106142

107-
Value inductionVarAfter = afterBody->getArgument(argNumber);
143+
auto findAfterArgNo = [](ArrayRef<unsigned> indices, unsigned beforeArgNo) {
144+
return std::distance(indices.begin(),
145+
llvm::find_if(indices, [beforeArgNo](unsigned n) {
146+
return n == beforeArgNo;
147+
}));
148+
};
149+
Value inductionVarAfter = afterBody->getArgument(
150+
argReorder ? findAfterArgNo(*argReorder, argNumber) : argNumber);
108151

109152
// Find suitable `addi` op inside `after` block, one of the args must be an
110153
// Induction var passed from `before` block and second arg must be defined
@@ -130,7 +173,7 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
130173
assert(lb.getType() == ub.getType());
131174
assert(lb.getType() == step.getType());
132175

133-
llvm::SmallVector<Value> newArgs;
176+
SmallVector<Value> newArgs;
134177

135178
// Populate inits for new `scf.for`, skip induction var.
136179
newArgs.reserve(loop.getInits().size());
@@ -164,6 +207,14 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
164207
newArgs.emplace_back(newBodyArgs[i]);
165208
}
166209
}
210+
if (argReorder) {
211+
// Reorder arguments following the 'after' argument order from the original
212+
// 'while' loop.
213+
SmallVector<Value> args;
214+
for (unsigned order : *argReorder)
215+
args.push_back(newArgs[order]);
216+
newArgs = args;
217+
}
167218

168219
rewriter.inlineBlockBefore(loop.getAfterBody(), newBody, newBody->end(),
169220
newArgs);
@@ -205,6 +256,14 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
205256
newArgs.clear();
206257
llvm::append_range(newArgs, newLoop.getResults());
207258
newArgs.insert(newArgs.begin() + argNumber, res);
259+
if (argReorder) {
260+
// Reorder arguments following the 'after' argument order from the original
261+
// 'while' loop.
262+
SmallVector<Value> results;
263+
for (unsigned order : *argReorder)
264+
results.push_back(newArgs[order]);
265+
newArgs = results;
266+
}
208267
rewriter.replaceOp(loop, newArgs);
209268
return newLoop;
210269
}

mlir/test/Dialect/SCF/uplift-while.mlir

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,3 +155,33 @@ func.func @uplift_while(%arg0: i64, %arg1: i64, %arg2: i64) -> i64 {
155155
// CHECK: %[[R6:.*]] = arith.muli %[[R5]], %[[STEP]] : i64
156156
// CHECK: %[[R7:.*]] = arith.addi %[[BEGIN]], %[[R6]] : i64
157157
// CHECK: return %[[R7]] : i64
158+
159+
// -----
160+
161+
// A case where all 'before' arguments are forwarded but reordered.
162+
func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> (i32, f32) {
163+
%c1 = arith.constant 1 : i32
164+
%c2 = arith.constant 2.0 : f32
165+
%0:3 = scf.while (%arg4 = %c1, %arg3 = %arg0, %arg5 = %c2) : (i32, index, f32) -> (index, i32, f32) {
166+
%1 = arith.cmpi slt, %arg3, %arg1 : index
167+
scf.condition(%1) %arg3, %arg4, %arg5 : index, i32, f32
168+
} do {
169+
^bb0(%arg3: index, %arg4: i32, %arg5: f32):
170+
%1 = "test.test1"(%arg4) : (i32) -> i32
171+
%added = arith.addi %arg3, %arg2 : index
172+
%2 = "test.test2"(%arg5) : (f32) -> f32
173+
scf.yield %1, %added, %2 : i32, index, f32
174+
}
175+
return %0#1, %0#2 : i32, f32
176+
}
177+
178+
// CHECK-LABEL: func @uplift_while
179+
// CHECK-SAME: (%[[BEGIN:.*]]: index, %[[END:.*]]: index, %[[STEP:.*]]: index) -> (i32, f32)
180+
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : i32
181+
// CHECK-DAG: %[[C2:.*]] = arith.constant 2.000000e+00 : f32
182+
// CHECK: %[[RES:.*]]:2 = scf.for %[[I:.*]] = %[[BEGIN]] to %[[END]] step %[[STEP]]
183+
// CHECK-SAME: iter_args(%[[ARG1:.*]] = %[[C1]], %[[ARG2:.*]] = %[[C2]]) -> (i32, f32) {
184+
// CHECK: %[[T1:.*]] = "test.test1"(%[[ARG1]]) : (i32) -> i32
185+
// CHECK: %[[T2:.*]] = "test.test2"(%[[ARG2]]) : (f32) -> f32
186+
// CHECK: scf.yield %[[T1]], %[[T2]] : i32, f32
187+
// CHECK: return %[[RES]]#0, %[[RES]]#1 : i32, f32

0 commit comments

Comments
 (0)