Skip to content

Commit 3401122

Browse files
[CIR][ThroughMLIR] Lower ContinueOp nested inside IfOp (#1682)
As the scf dialect does not support early exits, it might be necessary to change the body of WhileOp to implement the semantics of ContinueOp. I choose to add a guard `if (!cond)` for everything following the `continue`. Co-authored-by: Yue Huang <[email protected]>
1 parent 849edca commit 3401122

File tree

2 files changed

+113
-8
lines changed

2 files changed

+113
-8
lines changed

clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRLoopToSCF.cpp

Lines changed: 69 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,62 @@ class CIRForOpLowering : public mlir::OpConversionPattern<cir::ForOp> {
414414
};
415415

416416
class CIRWhileOpLowering : public mlir::OpConversionPattern<cir::WhileOp> {
417+
void rewriteContinueInIf(cir::IfOp ifOp, cir::ContinueOp continueOp,
418+
mlir::scf::WhileOp whileOp,
419+
mlir::ConversionPatternRewriter &rewriter) const {
420+
auto loc = ifOp->getLoc();
421+
422+
rewriter.setInsertionPointToStart(whileOp.getAfterBody());
423+
auto boolTy = rewriter.getType<BoolType>();
424+
auto boolPtrTy = rewriter.getType<PointerType>(boolTy);
425+
auto alignment = rewriter.getI64IntegerAttr(4);
426+
auto condAlloca = rewriter.create<AllocaOp>(loc, boolPtrTy, boolTy,
427+
"condition", alignment);
428+
429+
rewriter.setInsertionPoint(ifOp);
430+
auto negated = rewriter.create<UnaryOp>(loc, boolTy, UnaryOpKind::Not,
431+
ifOp.getCondition());
432+
rewriter.create<StoreOp>(loc, negated, condAlloca);
433+
434+
// On each layer, surround everything after runner in its parent with a
435+
// guard: `if (!condAlloca)`.
436+
for (mlir::Operation *runner = ifOp; runner != whileOp;
437+
runner = runner->getParentOp()) {
438+
rewriter.setInsertionPointAfter(runner);
439+
auto cond = rewriter.create<LoadOp>(
440+
loc, boolTy, condAlloca, /*isDeref=*/false,
441+
/*volatile=*/false, /*nontemporal=*/false, alignment,
442+
/*memorder=*/cir::MemOrderAttr{}, /*tbaa=*/cir::TBAAAttr{});
443+
auto ifnot =
444+
rewriter.create<IfOp>(loc, cond, /*withElseRegion=*/false,
445+
[&](mlir::OpBuilder &, mlir::Location) {
446+
/* Intentionally left empty */
447+
});
448+
449+
auto &region = ifnot.getThenRegion();
450+
rewriter.setInsertionPointToEnd(&region.back());
451+
auto terminator = rewriter.create<YieldOp>(loc);
452+
453+
bool inserted = false;
454+
for (mlir::Operation *op = ifnot->getNextNode(); op;) {
455+
// Don't move terminators in.
456+
if (isa<YieldOp>(op) || isa<ReturnOp>(op))
457+
break;
458+
459+
mlir::Operation *next = op->getNextNode();
460+
op->moveBefore(terminator);
461+
op = next;
462+
inserted = true;
463+
}
464+
// Don't retain `if (!condAlloca)` when it's empty.
465+
if (!inserted)
466+
rewriter.eraseOp(ifnot);
467+
}
468+
rewriter.setInsertionPoint(continueOp);
469+
rewriter.create<mlir::scf::YieldOp>(continueOp->getLoc());
470+
rewriter.eraseOp(continueOp);
471+
}
472+
417473
void rewriteContinue(mlir::scf::WhileOp whileOp,
418474
mlir::ConversionPatternRewriter &rewriter) const {
419475
// Collect all ContinueOp inside this while.
@@ -427,23 +483,29 @@ class CIRWhileOpLowering : public mlir::OpConversionPattern<cir::WhileOp> {
427483
return;
428484

429485
for (auto continueOp : continues) {
430-
// When the break is under an IfOp, a direct replacement of `scf.yield`
431-
// won't work: the yield would jump out of that IfOp instead. We might
432-
// need to change the whileOp itself to achieve the same effect.
486+
// When the ContinueOp is under an IfOp, a direct replacement of
487+
// `scf.yield` won't work: the yield would jump out of that IfOp instead.
488+
// We might need to change the WhileOp itself to achieve the same effect.
489+
bool rewritten = false;
433490
for (mlir::Operation *parent = continueOp->getParentOp();
434491
parent != whileOp; parent = parent->getParentOp()) {
435-
if (isa<mlir::scf::IfOp>(parent) || isa<cir::IfOp>(parent))
436-
llvm_unreachable("NYI");
492+
if (auto ifOp = dyn_cast<cir::IfOp>(parent)) {
493+
rewriteContinueInIf(ifOp, continueOp, whileOp, rewriter);
494+
rewritten = true;
495+
break;
496+
}
437497
}
498+
if (rewritten)
499+
continue;
438500

439-
// Operations after this break has to be removed.
501+
// Operations after this ContinueOp has to be removed.
440502
for (mlir::Operation *runner = continueOp->getNextNode(); runner;) {
441503
mlir::Operation *next = runner->getNextNode();
442504
runner->erase();
443505
runner = next;
444506
}
445507

446-
// Blocks after this break also has to be removed.
508+
// Blocks after this ContinueOp also has to be removed.
447509
for (mlir::Block *block = continueOp->getBlock()->getNextNode(); block;) {
448510
mlir::Block *next = block->getNextNode();
449511
block->erase();

clang/test/CIR/Lowering/ThroughMLIR/while-with-continue.cpp

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -fno-clangir-direct-lowering -emit-mlir=core %s -o %t.mlir
22
// RUN: FileCheck --input-file=%t.mlir %s
33

4-
void for_with_break() {
4+
void while_continue() {
55
int i = 0;
66
while (i < 100) {
77
i++;
@@ -25,3 +25,46 @@ void for_with_break() {
2525
// CHECK: scf.yield
2626
// CHECK: }
2727
}
28+
29+
void while_continue_2() {
30+
int i = 0;
31+
while (i < 10) {
32+
if (i == 5) {
33+
i += 3;
34+
continue;
35+
}
36+
37+
i++;
38+
}
39+
// The final i++ will have a `if (!(i == 5))` guarded against it.
40+
41+
// CHECK: do {
42+
// CHECK: %[[NOTALLOCA:.+]] = memref.alloca
43+
// CHECK: memref.alloca_scope {
44+
// CHECK: memref.alloca_scope {
45+
// CHECK: %[[IV:.+]] = memref.load %[[IVADDR:.+]][]
46+
// CHECK: %[[FIVE:.+]] = arith.constant 5
47+
// CHECK: %[[COND:.+]] = arith.cmpi eq, %[[IV]], %[[FIVE]]
48+
// CHECK: %true = arith.constant true
49+
// CHECK: %[[NOT:.+]] = arith.xori %true, %[[COND]]
50+
// CHECK: %[[EXT:.+]] = arith.extui %[[NOT]] : i1 to i8
51+
// CHECK: memref.store %[[EXT]], %[[NOTALLOCA]]
52+
// CHECK: scf.if %[[COND]] {
53+
// CHECK: %[[THREE:.+]] = arith.constant 3
54+
// CHECK: %[[IV2:.+]] = memref.load %[[IVADDR]]
55+
// CHECK: %[[TMP:.+]] = arith.addi %[[IV2]], %[[THREE]]
56+
// CHECK: memref.store %[[TMP]], %[[IVADDR]]
57+
// CHECK: }
58+
// CHECK: }
59+
// CHECK: %[[NOTCOND:.+]] = memref.load %[[NOTALLOCA]]
60+
// CHECK: %[[TRUNC:.+]] = arith.trunci %[[NOTCOND]] : i8 to i1
61+
// CHECK: scf.if %[[TRUNC]] {
62+
// CHECK: %[[IV3:.+]] = memref.load %[[IVADDR]]
63+
// CHECK: %[[ONE:.+]] = arith.constant 1
64+
// CHECK: %[[TMP2:.+]] = arith.addi %[[IV3]], %[[ONE]]
65+
// CHECK: memref.store %[[TMP2]], %[[IVADDR]]
66+
// CHECK: }
67+
// CHECK: }
68+
// CHECK: scf.yield
69+
// CHECK: }
70+
}

0 commit comments

Comments
 (0)