Skip to content

Commit 4c0e255

Browse files
committed
[mlir] Add lowering to CFG for WhileOp
The lowering is a straightforward inlining of the "before" and "after" regions connected by (conditional) branches. This plugs the WhileOp into the progressive lowering scheme. Future commits may choose to target WhileOp instead of CFG when lowering ForOp. Differential Revision: https://reviews.llvm.org/D90603
1 parent 7971655 commit 4c0e255

File tree

2 files changed

+226
-4
lines changed

2 files changed

+226
-4
lines changed

mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp

Lines changed: 113 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,72 @@ struct ParallelLowering : public OpRewritePattern<mlir::scf::ParallelOp> {
200200
LogicalResult matchAndRewrite(mlir::scf::ParallelOp parallelOp,
201201
PatternRewriter &rewriter) const override;
202202
};
203+
204+
/// Create a CFG subgraph for this loop construct. The regions of the loop need
205+
/// not be a single block anymore (for example, if other SCF constructs that
206+
/// they contain have been already converted to CFG), but need to be single-exit
207+
/// from the last block of each region. The operations following the original
208+
/// WhileOp are split into a new continuation block. Both regions of the WhileOp
209+
/// are inlined, and their terminators are rewritten to organize the control
210+
/// flow implementing the loop as follows.
211+
///
212+
/// +---------------------------------+
213+
/// | <code before the WhileOp> |
214+
/// | br ^before(%operands...) |
215+
/// +---------------------------------+
216+
/// |
217+
/// -------| |
218+
/// | v v
219+
/// | +--------------------------------+
220+
/// | | ^before(%bargs...): |
221+
/// | | %vals... = <some payload> |
222+
/// | +--------------------------------+
223+
/// | |
224+
/// | ...
225+
/// | |
226+
/// | +--------------------------------+
227+
/// | | ^before-last: |
228+
/// | | %cond = <compute condition> |
229+
/// | | cond_br %cond, |
230+
/// | | ^after(%vals...), ^cont |
231+
/// | +--------------------------------+
232+
/// | | |
233+
/// | | -------------|
234+
/// | v |
235+
/// | +--------------------------------+ |
236+
/// | | ^after(%aargs...): | |
237+
/// | | <body contents> | |
238+
/// | +--------------------------------+ |
239+
/// | | |
240+
/// | ... |
241+
/// | | |
242+
/// | +--------------------------------+ |
243+
/// | | ^after-last: | |
244+
/// | | %yields... = <some payload> | |
245+
/// | | br ^before(%yields...) | |
246+
/// | +--------------------------------+ |
247+
/// | | |
248+
/// |----------- |--------------------
249+
/// v
250+
/// +--------------------------------+
251+
/// | ^cont: |
252+
/// | <code after the WhileOp> |
253+
/// | <%vals from 'before' region |
254+
/// | visible by dominance> |
255+
/// +--------------------------------+
256+
///
257+
/// Values are communicated between ex-regions (the groups of blocks that used
258+
/// to form a region before inlining) through block arguments of their
259+
/// entry blocks, which are visible in all other dominated blocks. Similarly,
260+
/// the results of the WhileOp are defined in the 'before' region, which is
261+
/// required to have a single exiting block, and are therefore accessible in
262+
/// the continuation block due to dominance.
263+
struct WhileLowering : public OpRewritePattern<WhileOp> {
264+
using OpRewritePattern<WhileOp>::OpRewritePattern;
265+
266+
LogicalResult matchAndRewrite(WhileOp whileOp,
267+
PatternRewriter &rewriter) const override;
268+
};
203269
} // namespace
204270

205271
LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
@@ -399,18 +465,61 @@ ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
399465
return success();
400466
}
401467

468+
LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp,
469+
PatternRewriter &rewriter) const {
470+
OpBuilder::InsertionGuard guard(rewriter);
471+
Location loc = whileOp.getLoc();
472+
473+
// Split the current block before the WhileOp to create the inlining point.
474+
Block *currentBlock = rewriter.getInsertionBlock();
475+
Block *continuation =
476+
rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
477+
478+
// Inline both regions.
479+
Block *after = &whileOp.after().front();
480+
Block *afterLast = &whileOp.after().back();
481+
Block *before = &whileOp.before().front();
482+
Block *beforeLast = &whileOp.before().back();
483+
rewriter.inlineRegionBefore(whileOp.after(), continuation);
484+
rewriter.inlineRegionBefore(whileOp.before(), after);
485+
486+
// Branch to the "before" region.
487+
rewriter.setInsertionPointToEnd(currentBlock);
488+
rewriter.create<BranchOp>(loc, before, whileOp.inits());
489+
490+
// Replace terminators with branches. Assuming bodies are SESE, which holds
491+
// given only the patterns from this file, we only need to look at the last
492+
// block. This should be reconsidered if we allow break/continue in SCF.
493+
rewriter.setInsertionPointToEnd(beforeLast);
494+
auto condOp = cast<ConditionOp>(beforeLast->getTerminator());
495+
rewriter.replaceOpWithNewOp<CondBranchOp>(condOp, condOp.condition(), after,
496+
condOp.args(), continuation,
497+
ValueRange());
498+
499+
rewriter.setInsertionPointToEnd(afterLast);
500+
auto yieldOp = cast<scf::YieldOp>(afterLast->getTerminator());
501+
rewriter.replaceOpWithNewOp<BranchOp>(yieldOp, before, yieldOp.results());
502+
503+
// Replace the op with values "yielded" from the "before" region, which are
504+
// visible by dominance.
505+
rewriter.replaceOp(whileOp, condOp.args());
506+
507+
return success();
508+
}
509+
402510
void mlir::populateLoopToStdConversionPatterns(
403511
OwningRewritePatternList &patterns, MLIRContext *ctx) {
404-
patterns.insert<ForLowering, IfLowering, ParallelLowering>(ctx);
512+
patterns.insert<ForLowering, IfLowering, ParallelLowering, WhileLowering>(
513+
ctx);
405514
}
406515

407516
void SCFToStandardPass::runOnOperation() {
408517
OwningRewritePatternList patterns;
409518
populateLoopToStdConversionPatterns(patterns, &getContext());
410-
// Configure conversion to lower out scf.for, scf.if and scf.parallel.
411-
// Anything else is fine.
519+
// Configure conversion to lower out scf.for, scf.if, scf.parallel and
520+
// scf.while. Anything else is fine.
412521
ConversionTarget target(getContext());
413-
target.addIllegalOp<scf::ForOp, scf::IfOp, scf::ParallelOp>();
522+
target.addIllegalOp<scf::ForOp, scf::IfOp, scf::ParallelOp, scf::WhileOp>();
414523
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
415524
if (failed(
416525
applyPartialConversion(getOperation(), target, std::move(patterns))))

mlir/test/Conversion/SCFToStandard/convert-to-cfg.mlir

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -412,3 +412,116 @@ func @unknown_op_inside_loop(%arg0: index, %arg1: index, %arg2: index) {
412412
}
413413
return
414414
}
415+
416+
// CHECK-LABEL: @minimal_while
417+
func @minimal_while() {
418+
// CHECK: %[[COND:.*]] = "test.make_condition"() : () -> i1
419+
// CHECK: br ^[[BEFORE:.*]]
420+
%0 = "test.make_condition"() : () -> i1
421+
scf.while : () -> () {
422+
// CHECK: ^[[BEFORE]]:
423+
// CHECK: cond_br %[[COND]], ^[[AFTER:.*]], ^[[CONT:.*]]
424+
scf.condition(%0)
425+
} do {
426+
// CHECK: ^[[AFTER]]:
427+
// CHECK: br ^[[BEFORE]]
428+
scf.yield
429+
}
430+
// CHECK: ^[[CONT]]:
431+
// CHECK: return
432+
return
433+
}
434+
435+
// CHECK-LABEL: @while_values
436+
// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: f32)
437+
func @while_values(%arg0: i32, %arg1: f32) {
438+
// CHECK: %[[COND:.*]] = "test.make_condition"() : () -> i1
439+
%0 = "test.make_condition"() : () -> i1
440+
%c0_i32 = constant 0 : i32
441+
%cst = constant 0.000000e+00 : f32
442+
// CHECK: br ^[[BEFORE:.*]](%[[ARG0]], %[[ARG1]] : i32, f32)
443+
%1:2 = scf.while (%arg2 = %arg0, %arg3 = %arg1) : (i32, f32) -> (i64, f64) {
444+
// CHECK: ^[[BEFORE]](%[[ARG2:.*]]: i32, %[[ARG3:.*]]: f32):
445+
// CHECK: %[[VAL1:.*]] = zexti %[[ARG0]] : i32 to i64
446+
%2 = zexti %arg0 : i32 to i64
447+
// CHECK: %[[VAL2:.*]] = fpext %[[ARG3]] : f32 to f64
448+
%3 = fpext %arg3 : f32 to f64
449+
// CHECK: cond_br %[[COND]],
450+
// CHECK: ^[[AFTER:.*]](%[[VAL1]], %[[VAL2]] : i64, f64),
451+
// CHECK: ^[[CONT:.*]]
452+
scf.condition(%0) %2, %3 : i64, f64
453+
} do {
454+
// CHECK: ^[[AFTER]](%[[ARG4:.*]]: i64, %[[ARG5:.*]]: f64):
455+
^bb0(%arg2: i64, %arg3: f64): // no predecessors
456+
// CHECK: br ^[[BEFORE]](%{{.*}}, %{{.*}} : i32, f32)
457+
scf.yield %c0_i32, %cst : i32, f32
458+
}
459+
// CHECK: ^[[CONT]]:
460+
// CHECK: return
461+
return
462+
}
463+
464+
// CHECK-LABEL: @nested_while_ops
465+
func @nested_while_ops(%arg0: f32) -> i64 {
466+
// CHECK: br ^[[OUTER_BEFORE:.*]](%{{.*}} : f32)
467+
%0 = scf.while(%outer = %arg0) : (f32) -> i64 {
468+
// CHECK: ^[[OUTER_BEFORE]](%{{.*}}: f32):
469+
// CHECK: %[[OUTER_COND:.*]] = "test.outer_before_pre"() : () -> i1
470+
%cond = "test.outer_before_pre"() : () -> i1
471+
// CHECK: br ^[[INNER_BEFORE_BEFORE:.*]](%{{.*}} : f32)
472+
%1 = scf.while(%inner = %outer) : (f32) -> i64 {
473+
// CHECK: ^[[INNER_BEFORE_BEFORE]](%{{.*}}: f32):
474+
// CHECK: %[[INNER1:.*]]:2 = "test.inner_before"(%{{.*}}) : (f32) -> (i1, i64)
475+
%2:2 = "test.inner_before"(%inner) : (f32) -> (i1, i64)
476+
// CHECK: cond_br %[[INNER1]]#0,
477+
// CHECK: ^[[INNER_BEFORE_AFTER:.*]](%[[INNER1]]#1 : i64),
478+
// CHECK: ^[[OUTER_BEFORE_LAST:.*]]
479+
scf.condition(%2#0) %2#1 : i64
480+
} do {
481+
// CHECK: ^[[INNER_BEFORE_AFTER]](%{{.*}}: i64):
482+
^bb0(%arg1: i64):
483+
// CHECK: %[[INNER2:.*]] = "test.inner_after"(%{{.*}}) : (i64) -> f32
484+
%3 = "test.inner_after"(%arg1) : (i64) -> f32
485+
// CHECK: br ^[[INNER_BEFORE_BEFORE]](%[[INNER2]] : f32)
486+
scf.yield %3 : f32
487+
}
488+
// CHECK: ^[[OUTER_BEFORE_LAST]]:
489+
// CHECK: "test.outer_before_post"() : () -> ()
490+
"test.outer_before_post"() : () -> ()
491+
// CHECK: cond_br %[[OUTER_COND]],
492+
// CHECK: ^[[OUTER_AFTER:.*]](%[[INNER1]]#1 : i64),
493+
// CHECK: ^[[CONTINUATION:.*]]
494+
scf.condition(%cond) %1 : i64
495+
} do {
496+
// CHECK: ^[[OUTER_AFTER]](%{{.*}}: i64):
497+
^bb2(%arg2: i64):
498+
// CHECK: "test.outer_after_pre"(%{{.*}}) : (i64) -> ()
499+
"test.outer_after_pre"(%arg2) : (i64) -> ()
500+
// CHECK: br ^[[INNER_AFTER_BEFORE:.*]](%{{.*}} : i64)
501+
%4 = scf.while(%inner = %arg2) : (i64) -> f32 {
502+
// CHECK: ^[[INNER_AFTER_BEFORE]](%{{.*}}: i64):
503+
// CHECK: %[[INNER3:.*]]:2 = "test.inner2_before"(%{{.*}}) : (i64) -> (i1, f32)
504+
%5:2 = "test.inner2_before"(%inner) : (i64) -> (i1, f32)
505+
// CHECK: cond_br %[[INNER3]]#0,
506+
// CHECK: ^[[INNER_AFTER_AFTER:.*]](%[[INNER3]]#1 : f32),
507+
// CHECK: ^[[OUTER_AFTER_LAST:.*]]
508+
scf.condition(%5#0) %5#1 : f32
509+
} do {
510+
// CHECK: ^[[INNER_AFTER_AFTER]](%{{.*}}: f32):
511+
^bb3(%arg3: f32):
512+
// CHECK: %{{.*}} = "test.inner2_after"(%{{.*}}) : (f32) -> i64
513+
%6 = "test.inner2_after"(%arg3) : (f32) -> i64
514+
// CHECK: br ^[[INNER_AFTER_BEFORE]](%{{.*}} : i64)
515+
scf.yield %6 : i64
516+
}
517+
// CHECK: ^[[OUTER_AFTER_LAST]]:
518+
// CHECK: "test.outer_after_post"() : () -> ()
519+
"test.outer_after_post"() : () -> ()
520+
// CHECK: br ^[[OUTER_BEFORE]](%[[INNER3]]#1 : f32)
521+
scf.yield %4 : f32
522+
}
523+
// CHECK: ^[[CONTINUATION]]:
524+
// CHECK: return %{{.*}} : i64
525+
return %0 : i64
526+
}
527+

0 commit comments

Comments
 (0)