@@ -266,6 +266,17 @@ struct WhileLowering : public OpRewritePattern<WhileOp> {
266
266
LogicalResult matchAndRewrite (WhileOp whileOp,
267
267
PatternRewriter &rewriter) const override ;
268
268
};
269
+
270
+ // / Optimized version of the above for the case of the "after" region merely
271
+ // / forwarding its arguments back to the "before" region (i.e., a "do-while"
272
+ // / loop). This avoid inlining the "after" region completely and branches back
273
+ // / to the "before" entry instead.
274
+ struct DoWhileLowering : public OpRewritePattern <WhileOp> {
275
+ using OpRewritePattern<WhileOp>::OpRewritePattern;
276
+
277
+ LogicalResult matchAndRewrite (WhileOp whileOp,
278
+ PatternRewriter &rewriter) const override ;
279
+ };
269
280
} // namespace
270
281
271
282
LogicalResult ForLowering::matchAndRewrite (ForOp forOp,
@@ -507,10 +518,60 @@ LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp,
507
518
return success ();
508
519
}
509
520
521
+ LogicalResult
522
+ DoWhileLowering::matchAndRewrite (WhileOp whileOp,
523
+ PatternRewriter &rewriter) const {
524
+ if (!llvm::hasSingleElement (whileOp.after ()))
525
+ return rewriter.notifyMatchFailure (whileOp,
526
+ " do-while simplification applicable to "
527
+ " single-block 'after' region only" );
528
+
529
+ Block &afterBlock = whileOp.after ().front ();
530
+ if (!llvm::hasSingleElement (afterBlock))
531
+ return rewriter.notifyMatchFailure (whileOp,
532
+ " do-while simplification applicable "
533
+ " only if 'after' region has no payload" );
534
+
535
+ auto yield = dyn_cast<scf::YieldOp>(&afterBlock.front ());
536
+ if (!yield || yield.results () != afterBlock.getArguments ())
537
+ return rewriter.notifyMatchFailure (whileOp,
538
+ " do-while simplification applicable "
539
+ " only to forwarding 'after' regions" );
540
+
541
+ // Split the current block before the WhileOp to create the inlining point.
542
+ OpBuilder::InsertionGuard guard (rewriter);
543
+ Block *currentBlock = rewriter.getInsertionBlock ();
544
+ Block *continuation =
545
+ rewriter.splitBlock (currentBlock, rewriter.getInsertionPoint ());
546
+
547
+ // Only the "before" region should be inlined.
548
+ Block *before = &whileOp.before ().front ();
549
+ Block *beforeLast = &whileOp.before ().back ();
550
+ rewriter.inlineRegionBefore (whileOp.before (), continuation);
551
+
552
+ // Branch to the "before" region.
553
+ rewriter.setInsertionPointToEnd (currentBlock);
554
+ rewriter.create <BranchOp>(whileOp.getLoc (), before, whileOp.inits ());
555
+
556
+ // Loop around the "before" region based on condition.
557
+ rewriter.setInsertionPointToEnd (beforeLast);
558
+ auto condOp = cast<ConditionOp>(beforeLast->getTerminator ());
559
+ rewriter.replaceOpWithNewOp <CondBranchOp>(condOp, condOp.condition (), before,
560
+ condOp.args (), continuation,
561
+ ValueRange ());
562
+
563
+ // Replace the op with values "yielded" from the "before" region, which are
564
+ // visible by dominance.
565
+ rewriter.replaceOp (whileOp, condOp.args ());
566
+
567
+ return success ();
568
+ }
569
+
510
570
void mlir::populateLoopToStdConversionPatterns (
511
571
OwningRewritePatternList &patterns, MLIRContext *ctx) {
512
572
patterns.insert <ForLowering, IfLowering, ParallelLowering, WhileLowering>(
513
573
ctx);
574
+ patterns.insert <DoWhileLowering>(ctx, /* benefit=*/ 2 );
514
575
}
515
576
516
577
void SCFToStandardPass::runOnOperation () {
0 commit comments