Skip to content

Commit 8475fa6

Browse files
committed
[mlir] Add a simpler lowering pattern for WhileOp representing a do-while loop
When the "after" region of a WhileOp is merely forwarding its arguments back to the "before" region, i.e. WhileOp is a canonical do-while loop, a simpler CFG subgraph that omits the "after" region with its extra branch operation can be produced. Loop rotation from general "while" to "if { do-while }" is left for a future canonicalization pattern when it becomes necessary. Differential Revision: https://reviews.llvm.org/D90604
1 parent 4c0e255 commit 8475fa6

File tree

2 files changed

+82
-0
lines changed

2 files changed

+82
-0
lines changed

mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,17 @@ struct WhileLowering : public OpRewritePattern<WhileOp> {
266266
LogicalResult matchAndRewrite(WhileOp whileOp,
267267
PatternRewriter &rewriter) const override;
268268
};
269+
270+
/// Optimized version of the above for the case of the "after" region merely
271+
/// forwarding its arguments back to the "before" region (i.e., a "do-while"
272+
/// loop). This avoid inlining the "after" region completely and branches back
273+
/// to the "before" entry instead.
274+
struct DoWhileLowering : public OpRewritePattern<WhileOp> {
275+
using OpRewritePattern<WhileOp>::OpRewritePattern;
276+
277+
LogicalResult matchAndRewrite(WhileOp whileOp,
278+
PatternRewriter &rewriter) const override;
279+
};
269280
} // namespace
270281

271282
LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
@@ -507,10 +518,60 @@ LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp,
507518
return success();
508519
}
509520

521+
LogicalResult
522+
DoWhileLowering::matchAndRewrite(WhileOp whileOp,
523+
PatternRewriter &rewriter) const {
524+
if (!llvm::hasSingleElement(whileOp.after()))
525+
return rewriter.notifyMatchFailure(whileOp,
526+
"do-while simplification applicable to "
527+
"single-block 'after' region only");
528+
529+
Block &afterBlock = whileOp.after().front();
530+
if (!llvm::hasSingleElement(afterBlock))
531+
return rewriter.notifyMatchFailure(whileOp,
532+
"do-while simplification applicable "
533+
"only if 'after' region has no payload");
534+
535+
auto yield = dyn_cast<scf::YieldOp>(&afterBlock.front());
536+
if (!yield || yield.results() != afterBlock.getArguments())
537+
return rewriter.notifyMatchFailure(whileOp,
538+
"do-while simplification applicable "
539+
"only to forwarding 'after' regions");
540+
541+
// Split the current block before the WhileOp to create the inlining point.
542+
OpBuilder::InsertionGuard guard(rewriter);
543+
Block *currentBlock = rewriter.getInsertionBlock();
544+
Block *continuation =
545+
rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
546+
547+
// Only the "before" region should be inlined.
548+
Block *before = &whileOp.before().front();
549+
Block *beforeLast = &whileOp.before().back();
550+
rewriter.inlineRegionBefore(whileOp.before(), continuation);
551+
552+
// Branch to the "before" region.
553+
rewriter.setInsertionPointToEnd(currentBlock);
554+
rewriter.create<BranchOp>(whileOp.getLoc(), before, whileOp.inits());
555+
556+
// Loop around the "before" region based on condition.
557+
rewriter.setInsertionPointToEnd(beforeLast);
558+
auto condOp = cast<ConditionOp>(beforeLast->getTerminator());
559+
rewriter.replaceOpWithNewOp<CondBranchOp>(condOp, condOp.condition(), before,
560+
condOp.args(), continuation,
561+
ValueRange());
562+
563+
// Replace the op with values "yielded" from the "before" region, which are
564+
// visible by dominance.
565+
rewriter.replaceOp(whileOp, condOp.args());
566+
567+
return success();
568+
}
569+
510570
void mlir::populateLoopToStdConversionPatterns(
511571
OwningRewritePatternList &patterns, MLIRContext *ctx) {
512572
patterns.insert<ForLowering, IfLowering, ParallelLowering, WhileLowering>(
513573
ctx);
574+
patterns.insert<DoWhileLowering>(ctx, /*benefit=*/2);
514575
}
515576

516577
void SCFToStandardPass::runOnOperation() {

mlir/test/Conversion/SCFToStandard/convert-to-cfg.mlir

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,8 @@ func @minimal_while() {
424424
scf.condition(%0)
425425
} do {
426426
// CHECK: ^[[AFTER]]:
427+
// CHECK: "test.some_payload"() : () -> ()
428+
"test.some_payload"() : () -> ()
427429
// CHECK: br ^[[BEFORE]]
428430
scf.yield
429431
}
@@ -432,6 +434,25 @@ func @minimal_while() {
432434
return
433435
}
434436

437+
// CHECK-LABEL: @do_while
438+
func @do_while(%arg0: f32) {
439+
// CHECK: br ^[[BEFORE:.*]]({{.*}}: f32)
440+
scf.while (%arg1 = %arg0) : (f32) -> (f32) {
441+
// CHECK: ^[[BEFORE]](%[[VAL:.*]]: f32):
442+
// CHECK: %[[COND:.*]] = "test.make_condition"() : () -> i1
443+
%0 = "test.make_condition"() : () -> i1
444+
// CHECK: cond_br %[[COND]], ^[[BEFORE]](%[[VAL]] : f32), ^[[CONT:.*]]
445+
scf.condition(%0) %arg1 : f32
446+
} do {
447+
^bb0(%arg2: f32):
448+
// CHECK-NOT: br ^[[BEFORE]]
449+
scf.yield %arg2 : f32
450+
}
451+
// CHECK: ^[[CONT]]:
452+
// CHECK: return
453+
return
454+
}
455+
435456
// CHECK-LABEL: @while_values
436457
// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: f32)
437458
func @while_values(%arg0: i32, %arg1: f32) {

0 commit comments

Comments
 (0)