Skip to content

Commit 29f0357

Browse files
author
Jerry Wu
committed
Implement replaceWithZeroTripCheck for scf.while
1 parent 70f54b5 commit 29f0357

File tree

2 files changed

+107
-1
lines changed

2 files changed

+107
-1
lines changed

mlir/include/mlir/Dialect/SCF/IR/SCFOps.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -939,7 +939,9 @@ def WhileOp : SCF_Op<"while",
939939
[DeclareOpInterfaceMethods<RegionBranchOpInterface,
940940
["getEntrySuccessorOperands"]>,
941941
DeclareOpInterfaceMethods<LoopLikeOpInterface,
942-
["getRegionIterArgs", "getYieldedValuesMutable"]>,
942+
["getRegionIterArgs",
943+
"getYieldedValuesMutable",
944+
"replaceWithZeroTripCheck"]>,
943945
RecursiveMemoryEffects, SingleBlock]> {
944946
let summary = "a generic 'while' loop";
945947
let description = [{

mlir/lib/Dialect/SCF/IR/SCF.cpp

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3254,6 +3254,110 @@ LogicalResult scf::WhileOp::verify() {
32543254
return success(afterTerminator != nullptr);
32553255
}
32563256

3257+
/// Create zero-trip-check for a `while` op. Given an example below:
3258+
///
3259+
/// scf.while (%arg0 = %init) : (i32) -> i64 {
3260+
/// %val = .., %arg0 : i64
3261+
/// %cond = arith.cmpi .., %arg0 : i32
3262+
/// scf.condition(%cond) %val : i64
3263+
/// } do {
3264+
/// ^bb0(%arg1: i64):
3265+
/// %next = .., %arg1 : i32
3266+
/// scf.yield %next : i32
3267+
/// }
3268+
///
3269+
/// First clone the before block's ops immediately before the loop:
3270+
///
3271+
/// %val0 = .., %init : i64
3272+
/// %cond0 = arith.cmpi .., %init : i32
3273+
/// scf.while (%arg0 = %init) : (i32) -> i64 {
3274+
/// %val = .., %arg0 : i64
3275+
/// %cond = arith.cmpi .., %arg0 : i32
3276+
/// scf.condition(%cond) %val : i64
3277+
/// } do {
3278+
/// ^bb0(%arg1: i64):
3279+
/// %next = .., %arg1 : i32
3280+
/// scf.yield %next : i32
3281+
/// }
3282+
///
3283+
/// Create `if` op with the condition, rotate and move the loop into the then
3284+
/// branch:
3285+
///
3286+
/// %val0 = .., %init : i64
3287+
/// %cond0 = arith.cmpi .., %init : i32
3288+
/// scf.if %cond0 -> i64 {
3289+
/// %res = scf.while (%arg1 = %val0) : (i64) -> i64 {
3290+
/// // Original after block
3291+
/// %next = .., %arg1 : i32
3292+
/// // Original before block
3293+
/// %val = .., %next : i64
3294+
/// %cond = arith.cmpi .., %next : i32
3295+
/// scf.condition(%cond) %val : i64
3296+
/// } do {
3297+
/// ^bb0(%arg2: i64):
3298+
/// scf.yield %arg2 : i64
3299+
/// }
3300+
/// scf.yield %res : i64
3301+
/// } else {
3302+
/// scf.yield %val0 : i64
3303+
/// }
3304+
FailureOr<LoopLikeOpInterface>
3305+
scf::WhileOp::replaceWithZeroTripCheck(RewriterBase &rewriter) {
// Wraps this loop in an `scf.if` guarded by the first-iteration condition and
// returns the new (rotated) `scf.while` created inside the then-branch.
3306+
IRMapping mapper;
3307+
Block *beforeBlock = this->getBeforeBody();
3308+
// Map each before-block argument to the matching loop init value, so the ops
// cloned below evaluate the condition for the very first iteration.
3309+
for (auto [arg, init] :
3310+
llvm::zip_equal(beforeBlock->getArguments(), this->getInits())) {
3311+
mapper.map(arg, init);
3312+
}
3313+
// Cloned ops are inserted immediately before this loop.
// NOTE(review): the rewriter's insertion point is changed here and never
// restored; confirm callers do not rely on it (an OpBuilder::InsertionGuard
// would make the intent explicit).
rewriter.setInsertionPoint(*this);
3314+
for (auto &op : *beforeBlock) {
3315+
// Stop at the terminator: `scf.condition` itself is not cloned, only the ops
// that compute its operands.
if (isa<scf::ConditionOp>(op)) {
3316+
break;
3317+
}
3318+
// Cloning every op unconditionally is safe: within a single block each
3319+
// definition precedes its uses, so it is already cloned and in `mapper`.
3320+
rewriter.insert(op.clone(mapper));
3321+
}
3322+
3323+
// Translate the condition and the values forwarded by `scf.condition` to
// their cloned counterparts; values without a mapping are used unchanged.
auto condOp = this->getConditionOp();
3324+
auto clonedCondition = mapper.lookupOrDefault(condOp.getCondition());
3325+
auto clonedCondArgs = llvm::map_to_vector(
3326+
condOp.getArgs(), [&](Value arg) { return mapper.lookupOrDefault(arg); });
3327+
3328+
// Create the zero-trip-check `scf.if` and build the rotated loop inside it.
3329+
// `newLoop` is assigned by the then-branch builder callback below.
scf::WhileOp newLoop = nullptr;
3330+
auto ifOp = rewriter.create<scf::IfOp>(
3331+
this->getLoc(), clonedCondition,
3332+
[&](OpBuilder &builder, Location loc) {
3333+
// Then branch: the first-iteration condition held, so run the loop. The
// new loop is seeded with the values produced by the cloned before block.
3334+
newLoop = builder.create<scf::WhileOp>(
3335+
loc, this->getResultTypes(), clonedCondArgs,
3336+
[&](OpBuilder &builder, Location loc, ValueRange args) {
3337+
// Rotate: new before block = original after block + original before block.
3338+
auto newBlock = builder.getBlock();
3339+
rewriter.mergeBlocks(this->getAfterBody(), newBlock, args);
3340+
// At this point the original `scf.yield` is the block terminator; its
// operands replace the original before-block arguments in the next merge.
auto yieldOp = cast<scf::YieldOp>(newBlock->getTerminator());
3341+
rewriter.mergeBlocks(this->getBeforeBody(), newBlock,
3342+
yieldOp.getResults());
3343+
// The yield has served its purpose; the merged-in `scf.condition` now
// terminates the block.
rewriter.eraseOp(yieldOp);
3344+
},
3345+
[&](OpBuilder &builder, Location loc, ValueRange args) {
3346+
// New after block only forwards its arguments back to the before block.
3347+
builder.create<scf::YieldOp>(loc, args);
3348+
});
3349+
builder.create<scf::YieldOp>(loc, newLoop.getResults());
3350+
},
3351+
[&](OpBuilder &builder, Location loc) {
3352+
// Else branch: the loop body never runs, so yield the values computed by
// the cloned before block.
3353+
builder.create<scf::YieldOp>(loc, clonedCondArgs);
3354+
});
3355+
3356+
// Redirect all uses of the original loop's results to the `scf.if` results.
rewriter.replaceOp(*this, ifOp);
3357+
3358+
// Hand back the rotated loop that now lives in the then-branch.
return cast<LoopLikeOpInterface>(newLoop.getOperation());
3359+
}
3360+
32573361
namespace {
32583362
/// Replace uses of the condition within the do block with true, since otherwise
32593363
/// the block would not be evaluated.

0 commit comments

Comments
 (0)