Skip to content

Commit 1a3d625

Browse files
committed
[mlir][scf] Add simple LICM pattern for scf.while
Move non-side-effecting ops from `before` region if all their args are defined outside the loop. This is cleanup needed for `scf.while` -> `scf.for` uplifting llvm#76108 as it expects `before` block consisting of single `cmp` op.
1 parent d980384 commit 1a3d625

File tree

2 files changed

+57
-4
lines changed

2 files changed

+57
-4
lines changed

mlir/lib/Dialect/SCF/IR/SCF.cpp

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3810,6 +3810,36 @@ struct WhileRemoveUnusedArgs : public OpRewritePattern<WhileOp> {
38103810
}
38113811
};
38123812

3813+
/// Simple Loop Invariant Code Motion pattern for `scf.while` op.
3814+
/// `scf.while` to `scf.for` uplifting expects `before` block consisting of
3815+
/// single `cmp` op.
3816+
/// Pattern moves ops from `before` block, doesn't visit nested regions.
3817+
struct SCFWhileLICM : public OpRewritePattern<WhileOp> {
3818+
using OpRewritePattern::OpRewritePattern;
3819+
3820+
LogicalResult matchAndRewrite(WhileOp loop,
3821+
PatternRewriter &rewriter) const override {
3822+
bool changed = false;
3823+
3824+
DominanceInfo dom;
3825+
Block *body = loop.getBeforeBody();
3826+
for (Operation &op :
3827+
llvm::make_early_inc_range(body->without_terminator())) {
3828+
if (llvm::any_of(op.getOperands(), [&](Value arg) {
3829+
return !dom.properlyDominates(arg, loop);
3830+
}))
3831+
continue;
3832+
3833+
if (!isMemoryEffectFree(&op))
3834+
continue;
3835+
3836+
rewriter.updateRootInPlace(&op, [&]() { op.moveBefore(loop); });
3837+
changed = true;
3838+
}
3839+
return success(changed);
3840+
}
3841+
};
3842+
38133843
/// Remove duplicated ConditionOp args.
38143844
struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> {
38153845
using OpRewritePattern::OpRewritePattern;
@@ -3879,7 +3909,7 @@ void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
38793909
results.add<RemoveLoopInvariantArgsFromBeforeBlock,
38803910
RemoveLoopInvariantValueYielded, WhileConditionTruth,
38813911
WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
3882-
WhileRemoveUnusedArgs>(context);
3912+
SCFWhileLICM, WhileRemoveUnusedArgs>(context);
38833913
}
38843914

38853915
//===----------------------------------------------------------------------===//

mlir/test/Dialect/SCF/canonicalize.mlir

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1022,10 +1022,10 @@ func.func @while_loop_invariant_argument_different_order(%arg : tensor<i32>) ->
10221022
// CHECK-SAME: (%[[ARG:.+]]: tensor<i32>)
10231023
// CHECK: %[[ZERO:.*]] = arith.constant dense<0>
10241024
// CHECK: %[[ONE:.*]] = arith.constant dense<1>
1025+
// CHECK: %[[COND:.*]] = arith.cmpi sgt, %[[ARG]], %[[ZERO]]
1026+
// CHECK: %[[COND1:.*]] = tensor.extract %[[COND]][]
10251027
// CHECK: %[[WHILE:.*]]:2 = scf.while (%[[ARG1:.*]] = %[[ONE]], %[[ARG4:.*]] = %[[ZERO]])
1026-
// CHECK: arith.cmpi sgt, %[[ARG]], %[[ZERO]]
1027-
// CHECK: tensor.extract %{{.*}}[]
1028-
// CHECK: scf.condition(%{{.*}}) %[[ARG1]], %[[ARG4]]
1028+
// CHECK: scf.condition(%[[COND1]]) %[[ARG1]], %[[ARG4]]
10291029
// CHECK: } do {
10301030
// CHECK: ^{{.*}}(%{{.*}}: tensor<i32>, %{{.*}}: tensor<i32>):
10311031
// CHECK: scf.yield %[[ZERO]], %[[ONE]]
@@ -1144,6 +1144,29 @@ func.func @while_duplicated_res() -> (i32, i32) {
11441144
// CHECK: }
11451145
// CHECK: return %[[RES]], %[[RES]] : i32, i32
11461146

1147+
// -----
1148+
1149+
func.func @while_licm(%arg1: i32, %arg2: i32, %arg3: i32) {
1150+
scf.while () : () -> () {
1151+
%val0 = arith.addi %arg1, %arg2 : i32
1152+
%val = arith.addi %val0, %arg3 : i32
1153+
%condition = "test.condition"(%val) : (i32) -> i1
1154+
scf.condition(%condition)
1155+
} do {
1156+
^bb0():
1157+
"test.test"() : () -> ()
1158+
scf.yield
1159+
}
1160+
return
1161+
}
1162+
// CHECK-LABEL: @while_licm
1163+
// CHECK-SAME: (%[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32)
1164+
// CHECK: %[[VAL0:.*]] = arith.addi %[[ARG1]], %[[ARG2]] : i32
1165+
// CHECK: %[[VAL1:.*]] = arith.addi %[[VAL0]], %[[ARG3]] : i32
1166+
// CHECK: scf.while
1167+
// CHECK-NEXT: %[[COND:.*]] = "test.condition"(%[[VAL1]]) : (i32) -> i1
1168+
// CHECK-NEXT: scf.condition(%[[COND]])
1169+
11471170

11481171
// -----
11491172

0 commit comments

Comments
 (0)