@@ -200,6 +200,72 @@ struct ParallelLowering : public OpRewritePattern<mlir::scf::ParallelOp> {
200
200
LogicalResult matchAndRewrite (mlir::scf::ParallelOp parallelOp,
201
201
PatternRewriter &rewriter) const override ;
202
202
};
203
+
204
+ // / Create a CFG subgraph for this loop construct. The regions of the loop need
205
+ // / not be a single block anymore (for example, if other SCF constructs that
206
+ // / they contain have been already converted to CFG), but need to be single-exit
207
+ // / from the last block of each region. The operations following the original
208
+ // / WhileOp are split into a new continuation block. Both regions of the WhileOp
209
+ // / are inlined, and their terminators are rewritten to organize the control
210
+ // / flow implementing the loop as follows.
211
+ // /
212
+ // / +---------------------------------+
213
+ // / | <code before the WhileOp> |
214
+ // / | br ^before(%operands...) |
215
+ // / +---------------------------------+
216
+ // / |
217
+ // / -------| |
218
+ // / | v v
219
+ // / | +--------------------------------+
220
+ // / | | ^before(%bargs...): |
221
+ // / | | %vals... = <some payload> |
222
+ // / | +--------------------------------+
223
+ // / | |
224
+ // / | ...
225
+ // / | |
226
+ // / | +--------------------------------+
227
+ // / | | ^before-last: |
228
+ // / | | %cond = <compute condition> |
229
+ // / | | cond_br %cond, |
230
+ // / | | ^after(%vals...), ^cont |
231
+ // / | +--------------------------------+
232
+ // / | | |
233
+ // / | | -------------|
234
+ // / | v |
235
+ // / | +--------------------------------+ |
236
+ // / | | ^after(%aargs...): | |
237
+ // / | | <body contents> | |
238
+ // / | +--------------------------------+ |
239
+ // / | | |
240
+ // / | ... |
241
+ // / | | |
242
+ // / | +--------------------------------+ |
243
+ // / | | ^after-last: | |
244
+ // / | | %yields... = <some payload> | |
245
+ // / | | br ^before(%yields...) | |
246
+ // / | +--------------------------------+ |
247
+ // / | | |
248
+ // / |----------- |--------------------
249
+ // / v
250
+ // / +--------------------------------+
251
+ // / | ^cont: |
252
+ // / | <code after the WhileOp> |
253
+ // / | <%vals from 'before' region |
254
+ // / | visible by dominance> |
255
+ // / +--------------------------------+
256
+ // /
257
+ // / Values are communicated between ex-regions (the groups of blocks that used
258
+ // / to form a region before inlining) through block arguments of their
259
+ // / entry blocks, which are visible in all other dominated blocks. Similarly,
260
+ // / the results of the WhileOp are defined in the 'before' region, which is
261
+ // / required to have a single exiting block, and are therefore accessible in
262
+ // / the continuation block due to dominance.
263
// Rewrite pattern matching WhileOp. Only the interface is declared here; the
// rewrite logic lives in the out-of-line definition of matchAndRewrite.
+ struct WhileLowering : public OpRewritePattern <WhileOp> {
264
// Reuse the constructors of the OpRewritePattern<WhileOp> base class.
+ using OpRewritePattern<WhileOp>::OpRewritePattern;
265
+
266
// Rewrites `whileOp` through `rewriter` and reports the outcome as a
// LogicalResult.
+ LogicalResult matchAndRewrite (WhileOp whileOp,
267
+ PatternRewriter &rewriter) const override ;
268
+ };
203
269
} // namespace
204
270
205
271
LogicalResult ForLowering::matchAndRewrite (ForOp forOp,
@@ -399,18 +465,61 @@ ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
399
465
return success ();
400
466
}
401
467
468
+ LogicalResult WhileLowering::matchAndRewrite (WhileOp whileOp,
469
+ PatternRewriter &rewriter) const {
470
+ OpBuilder::InsertionGuard guard (rewriter);
471
+ Location loc = whileOp.getLoc ();
472
+
473
+ // Split the current block before the WhileOp to create the inlining point.
474
+ Block *currentBlock = rewriter.getInsertionBlock ();
475
+ Block *continuation =
476
+ rewriter.splitBlock (currentBlock, rewriter.getInsertionPoint ());
477
+
478
+ // Inline both regions.
479
+ Block *after = &whileOp.after ().front ();
480
+ Block *afterLast = &whileOp.after ().back ();
481
+ Block *before = &whileOp.before ().front ();
482
+ Block *beforeLast = &whileOp.before ().back ();
483
+ rewriter.inlineRegionBefore (whileOp.after (), continuation);
484
+ rewriter.inlineRegionBefore (whileOp.before (), after);
485
+
486
+ // Branch to the "before" region.
487
+ rewriter.setInsertionPointToEnd (currentBlock);
488
+ rewriter.create <BranchOp>(loc, before, whileOp.inits ());
489
+
490
+ // Replace terminators with branches. Assuming bodies are SESE, which holds
491
+ // given only the patterns from this file, we only need to look at the last
492
+ // block. This should be reconsidered if we allow break/continue in SCF.
493
+ rewriter.setInsertionPointToEnd (beforeLast);
494
+ auto condOp = cast<ConditionOp>(beforeLast->getTerminator ());
495
+ rewriter.replaceOpWithNewOp <CondBranchOp>(condOp, condOp.condition (), after,
496
+ condOp.args (), continuation,
497
+ ValueRange ());
498
+
499
+ rewriter.setInsertionPointToEnd (afterLast);
500
+ auto yieldOp = cast<scf::YieldOp>(afterLast->getTerminator ());
501
+ rewriter.replaceOpWithNewOp <BranchOp>(yieldOp, before, yieldOp.results ());
502
+
503
+ // Replace the op with values "yielded" from the "before" region, which are
504
+ // visible by dominance.
505
+ rewriter.replaceOp (whileOp, condOp.args ());
506
+
507
+ return success ();
508
+ }
509
+
402
510
void mlir::populateLoopToStdConversionPatterns (
403
511
OwningRewritePatternList &patterns, MLIRContext *ctx) {
404
- patterns.insert <ForLowering, IfLowering, ParallelLowering>(ctx);
512
+ patterns.insert <ForLowering, IfLowering, ParallelLowering, WhileLowering>(
513
+ ctx);
405
514
}
406
515
407
516
void SCFToStandardPass::runOnOperation () {
408
517
OwningRewritePatternList patterns;
409
518
populateLoopToStdConversionPatterns (patterns, &getContext ());
410
- // Configure conversion to lower out scf.for, scf.if and scf.parallel.
411
- // Anything else is fine.
519
+ // Configure conversion to lower out scf.for, scf.if, scf.parallel and
520
+ // scf.while. Anything else is fine.
412
521
ConversionTarget target (getContext ());
413
- target.addIllegalOp <scf::ForOp, scf::IfOp, scf::ParallelOp>();
522
+ target.addIllegalOp <scf::ForOp, scf::IfOp, scf::ParallelOp, scf::WhileOp >();
414
523
target.markUnknownOpDynamicallyLegal ([](Operation *) { return true ; });
415
524
if (failed (
416
525
applyPartialConversion (getOperation (), target, std::move (patterns))))
0 commit comments