Skip to content

Commit 16d1144

Browse files
committed
[mlir][scf] Align scf.while before block args in canonicalizer
If `before` block args are directly forwarded to `scf.condition` make sure they are passes in the same order. This is needed for `scf.while` uplifting #76108
1 parent d980384 commit 16d1144

File tree

2 files changed

+105
-1
lines changed

2 files changed

+105
-1
lines changed

mlir/lib/Dialect/SCF/IR/SCF.cpp

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3872,14 +3872,89 @@ struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> {
38723872
return success();
38733873
}
38743874
};
3875+
3876+
/// If both ranges contain same values return mappping indices from args1 to
3877+
/// args2. Otherwise return std::nullopt
3878+
static std::optional<SmallVector<unsigned>> getArgsMapping(ValueRange args1,
3879+
ValueRange args2) {
3880+
if (args1.size() != args2.size())
3881+
return std::nullopt;
3882+
3883+
SmallVector<unsigned> ret(args1.size());
3884+
for (auto &&[i, arg1] : llvm::enumerate(args1)) {
3885+
auto it = llvm::find(args2, arg1);
3886+
if (it == args2.end())
3887+
return std::nullopt;
3888+
3889+
auto j = it - args2.begin();
3890+
ret[j] = static_cast<unsigned>(i);
3891+
}
3892+
3893+
return ret;
3894+
}
3895+
3896+
/// If `before` block args are directly forwarded to `scf.condition`, rearrange
3897+
/// `scf.condition` args into same order as block args. Update `after` block
3898+
// args and results values accordingly.
3899+
/// Needed to simplify `scf.while` -> `scf.for` uplifting.
3900+
struct WhileOpAlignBeforeArgs : public OpRewritePattern<WhileOp> {
3901+
using OpRewritePattern::OpRewritePattern;
3902+
3903+
LogicalResult matchAndRewrite(WhileOp loop,
3904+
PatternRewriter &rewriter) const override {
3905+
auto oldBefore = loop.getBeforeBody();
3906+
ConditionOp oldTerm = loop.getConditionOp();
3907+
ValueRange beforeArgs = oldBefore->getArguments();
3908+
ValueRange termArgs = oldTerm.getArgs();
3909+
if (beforeArgs == termArgs)
3910+
return failure();
3911+
3912+
auto mapping = getArgsMapping(beforeArgs, termArgs);
3913+
if (!mapping)
3914+
return failure();
3915+
3916+
{
3917+
OpBuilder::InsertionGuard g(rewriter);
3918+
rewriter.setInsertionPoint(oldTerm);
3919+
rewriter.replaceOpWithNewOp<ConditionOp>(oldTerm, oldTerm.getCondition(),
3920+
beforeArgs);
3921+
}
3922+
3923+
auto oldAfter = loop.getAfterBody();
3924+
3925+
SmallVector<Type> newResultTypes(beforeArgs.size());
3926+
for (auto &&[i, j] : llvm::enumerate(*mapping))
3927+
newResultTypes[j] = loop.getResult(i).getType();
3928+
3929+
auto newLoop = rewriter.create<WhileOp>(loop.getLoc(), newResultTypes,
3930+
loop.getInits(), nullptr, nullptr);
3931+
auto newBefore = newLoop.getBeforeBody();
3932+
auto newAfter = newLoop.getAfterBody();
3933+
3934+
SmallVector<Value> newResults(beforeArgs.size());
3935+
SmallVector<Value> newAfterArgs(beforeArgs.size());
3936+
for (auto &&[i, j] : llvm::enumerate(*mapping)) {
3937+
newResults[i] = newLoop.getResult(j);
3938+
newAfterArgs[i] = newAfter->getArgument(j);
3939+
}
3940+
3941+
rewriter.inlineBlockBefore(oldBefore, newBefore, newBefore->begin(),
3942+
newBefore->getArguments());
3943+
rewriter.inlineBlockBefore(oldAfter, newAfter, newAfter->begin(),
3944+
newAfterArgs);
3945+
3946+
rewriter.replaceOp(loop, newResults);
3947+
return success();
3948+
}
3949+
};
38753950
} // namespace
38763951

38773952
void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
38783953
MLIRContext *context) {
38793954
results.add<RemoveLoopInvariantArgsFromBeforeBlock,
38803955
RemoveLoopInvariantValueYielded, WhileConditionTruth,
38813956
WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
3882-
WhileRemoveUnusedArgs>(context);
3957+
WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context);
38833958
}
38843959

38853960
//===----------------------------------------------------------------------===//

mlir/test/Dialect/SCF/canonicalize.mlir

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1198,6 +1198,35 @@ func.func @while_unused_arg2(%val0: i32) -> i32 {
11981198
// CHECK: return %[[RES]] : i32
11991199

12001200

1201+
// -----
1202+
1203+
// CHECK-LABEL: func @test_align_args
1204+
// CHECK: %[[RES:.*]]:3 = scf.while (%[[ARG0:.*]] = %{{.*}}, %[[ARG1:.*]] = %{{.*}}, %[[ARG2:.*]] = %{{.*}}) : (f32, i32, i64) -> (f32, i32, i64) {
1205+
// CHECK: scf.condition(%{{.*}}) %[[ARG0]], %[[ARG1]], %[[ARG2]] : f32, i32, i64
1206+
// CHECK: ^bb0(%[[ARG3:.*]]: f32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i64):
1207+
// CHECK: %[[R1:.*]] = "test.test"(%[[ARG5]]) : (i64) -> f32
1208+
// CHECK: %[[R2:.*]] = "test.test"(%[[ARG3]]) : (f32) -> i32
1209+
// CHECK: %[[R3:.*]] = "test.test"(%[[ARG4]]) : (i32) -> i64
1210+
// CHECK: scf.yield %[[R1]], %[[R2]], %[[R3]] : f32, i32, i64
1211+
// CHECK: return %[[RES]]#2, %[[RES]]#0, %[[RES]]#1
1212+
func.func @test_align_args() -> (i64, f32, i32) {
1213+
%0 = "test.test"() : () -> (f32)
1214+
%1 = "test.test"() : () -> (i32)
1215+
%2 = "test.test"() : () -> (i64)
1216+
%3:3 = scf.while (%arg0 = %0, %arg1 = %1, %arg2 = %2) : (f32, i32, i64) -> (i64, f32, i32) {
1217+
%cond = "test.test"() : () -> (i1)
1218+
scf.condition(%cond) %arg2, %arg0, %arg1 : i64, f32, i32
1219+
} do {
1220+
^bb0(%arg3: i64, %arg4: f32, %arg5: i32):
1221+
%4 = "test.test"(%arg3) : (i64) -> (f32)
1222+
%5 = "test.test"(%arg4) : (f32) -> (i32)
1223+
%6 = "test.test"(%arg5) : (i32) -> (i64)
1224+
scf.yield %4, %5, %6 : f32, i32, i64
1225+
}
1226+
return %3#0, %3#1, %3#2 : i64, f32, i32
1227+
}
1228+
1229+
12011230
// -----
12021231

12031232
// CHECK-LABEL: @combineIfs

0 commit comments

Comments
 (0)