Skip to content

Commit 1089292

Browse files
committed
[mlir][scf] Align scf.while before block args in canonicalizer
If `before` block args are directly forwarded to `scf.condition` make sure they are passes in the same order. This is needed for `scf.while` uplifting llvm#76108
1 parent 5b66b6a commit 1089292

File tree

2 files changed

+105
-1
lines changed

2 files changed

+105
-1
lines changed

mlir/lib/Dialect/SCF/IR/SCF.cpp

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3884,14 +3884,89 @@ struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> {
38843884
return success();
38853885
}
38863886
};
3887+
3888+
/// If both ranges contain same values return mappping indices from args1 to
3889+
/// args2. Otherwise return std::nullopt
3890+
static std::optional<SmallVector<unsigned>> getArgsMapping(ValueRange args1,
3891+
ValueRange args2) {
3892+
if (args1.size() != args2.size())
3893+
return std::nullopt;
3894+
3895+
SmallVector<unsigned> ret(args1.size());
3896+
for (auto &&[i, arg1] : llvm::enumerate(args1)) {
3897+
auto it = llvm::find(args2, arg1);
3898+
if (it == args2.end())
3899+
return std::nullopt;
3900+
3901+
auto j = it - args2.begin();
3902+
ret[j] = static_cast<unsigned>(i);
3903+
}
3904+
3905+
return ret;
3906+
}
3907+
3908+
/// If `before` block args are directly forwarded to `scf.condition`, rearrange
3909+
/// `scf.condition` args into same order as block args. Update `after` block
3910+
// args and results values accordingly.
3911+
/// Needed to simplify `scf.while` -> `scf.for` uplifting.
3912+
struct WhileOpAlignBeforeArgs : public OpRewritePattern<WhileOp> {
3913+
using OpRewritePattern::OpRewritePattern;
3914+
3915+
LogicalResult matchAndRewrite(WhileOp loop,
3916+
PatternRewriter &rewriter) const override {
3917+
auto oldBefore = loop.getBeforeBody();
3918+
ConditionOp oldTerm = loop.getConditionOp();
3919+
ValueRange beforeArgs = oldBefore->getArguments();
3920+
ValueRange termArgs = oldTerm.getArgs();
3921+
if (beforeArgs == termArgs)
3922+
return failure();
3923+
3924+
auto mapping = getArgsMapping(beforeArgs, termArgs);
3925+
if (!mapping)
3926+
return failure();
3927+
3928+
{
3929+
OpBuilder::InsertionGuard g(rewriter);
3930+
rewriter.setInsertionPoint(oldTerm);
3931+
rewriter.replaceOpWithNewOp<ConditionOp>(oldTerm, oldTerm.getCondition(),
3932+
beforeArgs);
3933+
}
3934+
3935+
auto oldAfter = loop.getAfterBody();
3936+
3937+
SmallVector<Type> newResultTypes(beforeArgs.size());
3938+
for (auto &&[i, j] : llvm::enumerate(*mapping))
3939+
newResultTypes[j] = loop.getResult(i).getType();
3940+
3941+
auto newLoop = rewriter.create<WhileOp>(loop.getLoc(), newResultTypes,
3942+
loop.getInits(), nullptr, nullptr);
3943+
auto newBefore = newLoop.getBeforeBody();
3944+
auto newAfter = newLoop.getAfterBody();
3945+
3946+
SmallVector<Value> newResults(beforeArgs.size());
3947+
SmallVector<Value> newAfterArgs(beforeArgs.size());
3948+
for (auto &&[i, j] : llvm::enumerate(*mapping)) {
3949+
newResults[i] = newLoop.getResult(j);
3950+
newAfterArgs[i] = newAfter->getArgument(j);
3951+
}
3952+
3953+
rewriter.inlineBlockBefore(oldBefore, newBefore, newBefore->begin(),
3954+
newBefore->getArguments());
3955+
rewriter.inlineBlockBefore(oldAfter, newAfter, newAfter->begin(),
3956+
newAfterArgs);
3957+
3958+
rewriter.replaceOp(loop, newResults);
3959+
return success();
3960+
}
3961+
};
38873962
} // namespace
38883963

38893964
void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
38903965
MLIRContext *context) {
38913966
results.add<RemoveLoopInvariantArgsFromBeforeBlock,
38923967
RemoveLoopInvariantValueYielded, WhileConditionTruth,
38933968
WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
3894-
WhileRemoveUnusedArgs>(context);
3969+
WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context);
38953970
}
38963971

38973972
//===----------------------------------------------------------------------===//

mlir/test/Dialect/SCF/canonicalize.mlir

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1198,6 +1198,35 @@ func.func @while_unused_arg2(%val0: i32) -> i32 {
11981198
// CHECK: return %[[RES]] : i32
11991199

12001200

1201+
// -----
1202+
1203+
// CHECK-LABEL: func @test_align_args
1204+
// CHECK: %[[RES:.*]]:3 = scf.while (%[[ARG0:.*]] = %{{.*}}, %[[ARG1:.*]] = %{{.*}}, %[[ARG2:.*]] = %{{.*}}) : (f32, i32, i64) -> (f32, i32, i64) {
1205+
// CHECK: scf.condition(%{{.*}}) %[[ARG0]], %[[ARG1]], %[[ARG2]] : f32, i32, i64
1206+
// CHECK: ^bb0(%[[ARG3:.*]]: f32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i64):
1207+
// CHECK: %[[R1:.*]] = "test.test"(%[[ARG5]]) : (i64) -> f32
1208+
// CHECK: %[[R2:.*]] = "test.test"(%[[ARG3]]) : (f32) -> i32
1209+
// CHECK: %[[R3:.*]] = "test.test"(%[[ARG4]]) : (i32) -> i64
1210+
// CHECK: scf.yield %[[R1]], %[[R2]], %[[R3]] : f32, i32, i64
1211+
// CHECK: return %[[RES]]#2, %[[RES]]#0, %[[RES]]#1
1212+
func.func @test_align_args() -> (i64, f32, i32) {
1213+
%0 = "test.test"() : () -> (f32)
1214+
%1 = "test.test"() : () -> (i32)
1215+
%2 = "test.test"() : () -> (i64)
1216+
%3:3 = scf.while (%arg0 = %0, %arg1 = %1, %arg2 = %2) : (f32, i32, i64) -> (i64, f32, i32) {
1217+
%cond = "test.test"() : () -> (i1)
1218+
scf.condition(%cond) %arg2, %arg0, %arg1 : i64, f32, i32
1219+
} do {
1220+
^bb0(%arg3: i64, %arg4: f32, %arg5: i32):
1221+
%4 = "test.test"(%arg3) : (i64) -> (f32)
1222+
%5 = "test.test"(%arg4) : (f32) -> (i32)
1223+
%6 = "test.test"(%arg5) : (i32) -> (i64)
1224+
scf.yield %4, %5, %6 : f32, i32, i64
1225+
}
1226+
return %3#0, %3#1, %3#2 : i64, f32, i32
1227+
}
1228+
1229+
12011230
// -----
12021231

12031232
// CHECK-LABEL: @combineIfs

0 commit comments

Comments
 (0)