Skip to content

Commit 4426258

Browse files
author
Jerry Wu
committed
Implement replaceWithZeroTripCheck for scf.while
1 parent b3cf941 commit 4426258

File tree

3 files changed

+131
-6
lines changed

3 files changed

+131
-6
lines changed

mlir/include/mlir/Dialect/SCF/IR/SCFOps.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -939,7 +939,9 @@ def WhileOp : SCF_Op<"while",
939939
[DeclareOpInterfaceMethods<RegionBranchOpInterface,
940940
["getEntrySuccessorOperands"]>,
941941
DeclareOpInterfaceMethods<LoopLikeOpInterface,
942-
["getRegionIterArgs", "getYieldedValuesMutable"]>,
942+
["getRegionIterArgs",
943+
"getYieldedValuesMutable",
944+
"replaceWithZeroTripCheck"]>,
943945
RecursiveMemoryEffects, SingleBlock]> {
944946
let summary = "a generic 'while' loop";
945947
let description = [{

mlir/lib/Dialect/SCF/IR/SCF.cpp

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3254,6 +3254,110 @@ LogicalResult scf::WhileOp::verify() {
32543254
return success(afterTerminator != nullptr);
32553255
}
32563256

3257+
/// Create zero-trip-check for a `while` op. Given an example below:
3258+
///
3259+
/// scf.while (%arg0 = %init) : (i32) -> i64 {
3260+
/// %val = .., %arg0 : i64
3261+
/// %cond = arith.cmpi .., %arg0 : i32
3262+
/// scf.condition(%cond) %val : i64
3263+
/// } do {
3264+
/// ^bb0(%arg1: i64):
3265+
/// %next = .., %arg1 : i32
3266+
/// scf.yield %next : i32
3267+
/// }
3268+
///
3269+
/// First clone before block to the front of the loop:
3270+
///
3271+
/// %pre_val = .., %init : i64
3272+
/// %pre_cond = arith.cmpi .., %init : i32
3273+
/// scf.while (%arg0 = %init) : (i32) -> i64 {
3274+
/// %val = .., %arg0 : i64
3275+
/// %cond = arith.cmpi .., %arg0 : i32
3276+
/// scf.condition(%cond) %val : i64
3277+
/// } do {
3278+
/// ^bb0(%arg1: i64):
3279+
/// %next = .., %arg1 : i32
3280+
/// scf.yield %next : i32
3281+
/// }
3282+
///
3283+
/// Create `if` op with the condition, rotate and move the loop into the else
3284+
/// branch:
3285+
///
3286+
/// %pre_val = .., %init : i64
3287+
/// %pre_cond = arith.cmpi .., %init : i32
3288+
/// scf.if %pre_cond -> i64 {
3289+
/// %res = scf.while (%arg1 = %va0) : (i64) -> i64 {
3290+
/// // Original after block
3291+
/// %next = .., %arg1 : i32
3292+
/// // Original before block
3293+
/// %val = .., %next : i64
3294+
/// %cond = arith.cmpi .., %next : i32
3295+
/// scf.condition(%cond) %val : i64
3296+
/// } do {
3297+
/// ^bb0(%arg2: i64):
3298+
/// %scf.yield %arg2 : i32
3299+
/// }
3300+
/// scf.yield %res : i64
3301+
/// } else {
3302+
/// scf.yield %pre_val : i64
3303+
/// }
3304+
FailureOr<LoopLikeOpInterface>
3305+
scf::WhileOp::replaceWithZeroTripCheck(RewriterBase &rewriter) {
3306+
IRMapping mapper;
3307+
Block *beforeBlock = this->getBeforeBody();
3308+
// Clone before block before the loop for zero-trip-check.
3309+
for (auto [arg, init] :
3310+
llvm::zip_equal(beforeBlock->getArguments(), this->getInits())) {
3311+
mapper.map(arg, init);
3312+
}
3313+
rewriter.setInsertionPoint(*this);
3314+
for (auto &op : *beforeBlock) {
3315+
if (isa<scf::ConditionOp>(op)) {
3316+
break;
3317+
}
3318+
// Safe to clone everything as in a single block all defs have been cloned
3319+
// and added to mapper in order.
3320+
rewriter.insert(op.clone(mapper));
3321+
}
3322+
3323+
auto condOp = this->getConditionOp();
3324+
auto clonedCondition = mapper.lookupOrDefault(condOp.getCondition());
3325+
auto clonedCondArgs = llvm::map_to_vector(
3326+
condOp.getArgs(), [&](Value arg) { return mapper.lookupOrDefault(arg); });
3327+
3328+
// Create zero-trip-check and move the while loop in.
3329+
scf::WhileOp newLoop = nullptr;
3330+
auto ifOp = rewriter.create<scf::IfOp>(
3331+
this->getLoc(), clonedCondition,
3332+
[&](OpBuilder &builder, Location loc) {
3333+
// Then runs the while loop.
3334+
newLoop = builder.create<scf::WhileOp>(
3335+
loc, this->getResultTypes(), clonedCondArgs,
3336+
[&](OpBuilder &builder, Location loc, ValueRange args) {
3337+
// Rotate and move the loop body into before block.
3338+
auto newBlock = builder.getBlock();
3339+
rewriter.mergeBlocks(this->getAfterBody(), newBlock, args);
3340+
auto yieldOp = cast<scf::YieldOp>(newBlock->getTerminator());
3341+
rewriter.mergeBlocks(this->getBeforeBody(), newBlock,
3342+
yieldOp.getResults());
3343+
rewriter.eraseOp(yieldOp);
3344+
},
3345+
[&](OpBuilder &builder, Location loc, ValueRange args) {
3346+
// Pass-through values.
3347+
builder.create<scf::YieldOp>(loc, args);
3348+
});
3349+
builder.create<scf::YieldOp>(loc, newLoop.getResults());
3350+
},
3351+
[&](OpBuilder &builder, Location loc) {
3352+
// Else returns the results from zero-trip-check.
3353+
builder.create<scf::YieldOp>(loc, clonedCondArgs);
3354+
});
3355+
3356+
rewriter.replaceOp(*this, ifOp);
3357+
3358+
return cast<LoopLikeOpInterface>(newLoop.getOperation());
3359+
}
3360+
32573361
namespace {
32583362
/// Replace uses of the condition within the do block with true, since otherwise
32593363
/// the block would not be evaluated.

mlir/test/Dialect/SCF/loop-zero-trip-check.mlir

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
// RUN: mlir-opt %s -test-loop-zero-trip-check -split-input-file | FileCheck %s
22

3-
func.func @no_replace_scf_while_with_zero_trip_check(%bound : i32) -> i32 {
3+
func.func @replace_scf_while_with_zero_trip_check(%bound : i32) -> i32 {
44
%cst0 = arith.constant 0 : i32
55
%cst5 = arith.constant 5 : i32
66
%res:2 = scf.while (%iter = %cst0) : (i32) -> (i32, i32) {
@@ -15,7 +15,26 @@ func.func @no_replace_scf_while_with_zero_trip_check(%bound : i32) -> i32 {
1515
return %res#0 : i32
1616
}
1717

18-
// TODO(pzread): Update the test once the replaceZeroTripCheck is implemented.
19-
// CHECK-LABEL: func.func @no_replace_scf_while_with_zero_trip_check
20-
// CHECK-NOT: scf.if
21-
// CHECK: scf.while
18+
// CHECK-LABEL: func.func @replace_scf_while_with_zero_trip_check(
19+
// CHECK-SAME: %[[ARG0:.*]]: i32) -> i32 {
20+
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32
21+
// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : i32
22+
// CHECK-DAG: %[[PRE_COND:.*]] = arith.cmpi slt, %[[C0]], %[[ARG0]] : i32
23+
// CHECK-DAG: %[[PRE_INV:.*]] = arith.addi %[[ARG0]], %[[C5]] : i32
24+
// CHECK: %[[IF:.*]]:2 = scf.if %[[PRE_COND]] -> (i32, i32) {
25+
// CHECK: %[[WHILE:.*]]:2 = scf.while (
26+
// CHECK-SAME: %[[ARG1:.*]] = %[[C0]], %[[ARG2:.*]] = %[[PRE_INV]]
27+
// CHECK-SAME: ) : (i32, i32) -> (i32, i32) {
28+
// CHECK: %[[NEXT:.*]] = arith.addi %[[ARG1]], %[[ARG2]] : i32
29+
// CHECK: %[[COND:.*]] = arith.cmpi slt, %[[NEXT]], %[[ARG0]] : i32
30+
// CHECK: %[[INV:.*]] = arith.addi %[[ARG0]], %[[C5]] : i32
31+
// CHECK: scf.condition(%[[COND]]) %[[NEXT]], %[[INV]] : i32, i32
32+
// CHECK: } do {
33+
// CHECK: ^bb0(%[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32):
34+
// CHECK: scf.yield %[[ARG3]], %[[ARG4]] : i32, i32
35+
// CHECK: }
36+
// CHECK: scf.yield %[[WHILE]]#0, %[[WHILE]]#1 : i32, i32
37+
// CHECK: } else {
38+
// CHECK: scf.yield %[[C0]], %[[PRE_INV]] : i32, i32
39+
// CHECK: }
40+
// CHECK: return %[[IF]]#0 : i32

0 commit comments

Comments
 (0)