Skip to content

[CIR] Upstream support for break and continue statements #134181

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,16 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
return create<cir::ForOp>(loc, condBuilder, bodyBuilder, stepBuilder);
}

/// Create a break operation.
cir::BreakOp createBreak(mlir::Location loc) {
return create<cir::BreakOp>(loc);
}

/// Create a continue operation.
cir::ContinueOp createContinue(mlir::Location loc) {
return create<cir::ContinueOp>(loc);
}

mlir::TypedAttr getConstPtrAttr(mlir::Type type, int64_t value) {
auto valueAttr = mlir::IntegerAttr::get(
mlir::IntegerType::get(type.getContext(), 64), value);
Expand Down
44 changes: 37 additions & 7 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -477,20 +477,20 @@ def ConditionOp : CIR_Op<"condition", [
`cir.bool` operand and, depending on its value, may branch to different
regions:

- When in the `cond` region of a `cir.loop`, it continues the loop
- When in the `cond` region of a loop, it continues the loop
if true, or exits it if false.
- When in the `ready` region of a `cir.await`, it branches to the `resume`
region when true, and to the `suspend` region when false.

Example:

```mlir
cir.loop for(cond : {
cir.condition(%arg0) // Branches to `step` region or exits.
}, step : {
[...]
}) {
[...]
cir.for cond {
cir.condition(%val) // Branches to `step` region or exits.
} body {
cir.yield
} step {
cir.yield
}

cir.await(user, ready : {
Expand Down Expand Up @@ -569,6 +569,36 @@ def YieldOp : CIR_Op<"yield", [ReturnLike, Terminator,
];
}

//===----------------------------------------------------------------------===//
// BreakOp
//===----------------------------------------------------------------------===//

def BreakOp : CIR_Op<"break", [Terminator]> {
let summary = "C/C++ `break` statement equivalent";
let description = [{
The `cir.break` operation is used to cease the execution of the current loop
or switch operation and transfer control to the parent operation. It is only
allowed within a breakable operations (loops and switches).
}];
let assemblyFormat = "attr-dict";
let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// ContinueOp
//===----------------------------------------------------------------------===//

def ContinueOp : CIR_Op<"continue", [Terminator]> {
let summary = "C/C++ `continue` statement equivalent";
let description = [{
The `cir.continue` operation is used to end execution of the current
iteration of a loop and resume execution beginning at the next iteration.
It is only allowed within loop regions.
}];
let assemblyFormat = "attr-dict";
let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// ScopeOp
//===----------------------------------------------------------------------===//
Expand Down
2 changes: 0 additions & 2 deletions clang/include/clang/CIR/MissingFeatures.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,12 +122,10 @@ struct MissingFeatures {

// Future CIR operations
static bool awaitOp() { return false; }
static bool breakOp() { return false; }
static bool callOp() { return false; }
static bool complexCreateOp() { return false; }
static bool complexImagOp() { return false; }
static bool complexRealOp() { return false; }
static bool continueOp() { return false; }
static bool ifOp() { return false; }
static bool labelOp() { return false; }
static bool ptrDiffOp() { return false; }
Expand Down
2 changes: 2 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,8 @@ class CIRGenFunction : public CIRGenTypeCache {

LValue emitBinaryOperatorLValue(const BinaryOperator *e);

mlir::LogicalResult emitBreakStmt(const clang::BreakStmt &s);
mlir::LogicalResult emitContinueStmt(const clang::ContinueStmt &s);
mlir::LogicalResult emitDoStmt(const clang::DoStmt &s);

/// Emit an expression as an initializer for an object (variable, field, etc.)
Expand Down
35 changes: 29 additions & 6 deletions clang/lib/CIR/CodeGen/CIRGenStmt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,12 @@ mlir::LogicalResult CIRGenFunction::emitStmt(const Stmt *s,
return mlir::success();

switch (s->getStmtClass()) {
case Stmt::BreakStmtClass:
case Stmt::CompoundStmtClass:
case Stmt::ContinueStmtClass:
case Stmt::DeclStmtClass:
case Stmt::ReturnStmtClass:
llvm_unreachable("should have emitted these statements as simple");

#define STMT(Type, Base)
#define ABSTRACT_STMT(Op)
Expand Down Expand Up @@ -88,13 +94,9 @@ mlir::LogicalResult CIRGenFunction::emitStmt(const Stmt *s,
case Stmt::SEHFinallyStmtClass:
case Stmt::MSDependentExistsStmtClass:
case Stmt::NullStmtClass:
case Stmt::CompoundStmtClass:
case Stmt::DeclStmtClass:
case Stmt::LabelStmtClass:
case Stmt::AttributedStmtClass:
case Stmt::GotoStmtClass:
case Stmt::BreakStmtClass:
case Stmt::ContinueStmtClass:
case Stmt::DefaultStmtClass:
case Stmt::CaseStmtClass:
case Stmt::SEHLeaveStmtClass:
Expand All @@ -106,7 +108,6 @@ mlir::LogicalResult CIRGenFunction::emitStmt(const Stmt *s,
case Stmt::CXXTryStmtClass:
case Stmt::CXXForRangeStmtClass:
case Stmt::IndirectGotoStmtClass:
case Stmt::ReturnStmtClass:
case Stmt::GCCAsmStmtClass:
case Stmt::MSAsmStmtClass:
case Stmt::OMPParallelDirectiveClass:
Expand Down Expand Up @@ -219,7 +220,6 @@ mlir::LogicalResult CIRGenFunction::emitSimpleStmt(const Stmt *s,
bool useCurrentScope) {
switch (s->getStmtClass()) {
default:
// Only compound and return statements are supported right now.
return mlir::failure();
case Stmt::DeclStmtClass:
return emitDeclStmt(cast<DeclStmt>(*s));
Expand All @@ -229,6 +229,10 @@ mlir::LogicalResult CIRGenFunction::emitSimpleStmt(const Stmt *s,
else
emitCompoundStmt(cast<CompoundStmt>(*s));
break;
case Stmt::ContinueStmtClass:
return emitContinueStmt(cast<ContinueStmt>(*s));
case Stmt::BreakStmtClass:
return emitBreakStmt(cast<BreakStmt>(*s));
case Stmt::ReturnStmtClass:
return emitReturnStmt(cast<ReturnStmt>(*s));
}
Expand Down Expand Up @@ -316,6 +320,25 @@ mlir::LogicalResult CIRGenFunction::emitReturnStmt(const ReturnStmt &s) {
return mlir::success();
}

mlir::LogicalResult
CIRGenFunction::emitContinueStmt(const clang::ContinueStmt &s) {
builder.createContinue(getLoc(s.getContinueLoc()));

// Insert the new block to continue codegen after the continue statement.
builder.createBlock(builder.getBlock()->getParent());

return mlir::success();
}

mlir::LogicalResult CIRGenFunction::emitBreakStmt(const clang::BreakStmt &s) {
builder.createBreak(getLoc(s.getBreakLoc()));

// Insert the new block to continue codegen after the break statement.
builder.createBlock(builder.getBlock()->getParent());

return mlir::success();
}

mlir::LogicalResult CIRGenFunction::emitForStmt(const ForStmt &s) {
cir::ForOp forOp;

Expand Down
21 changes: 21 additions & 0 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,17 @@ void cir::AllocaOp::build(mlir::OpBuilder &odsBuilder,
odsState.addTypes(addr);
}

//===----------------------------------------------------------------------===//
// BreakOp
//===----------------------------------------------------------------------===//

LogicalResult cir::BreakOp::verify() {
assert(!cir::MissingFeatures::switchOp());
if (!getOperation()->getParentOfType<LoopOpInterface>())
return emitOpError("must be within a loop");
return success();
}

//===----------------------------------------------------------------------===//
// ConditionOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -241,6 +252,16 @@ OpFoldResult cir::ConstantOp::fold(FoldAdaptor /*adaptor*/) {
return getValue();
}

//===----------------------------------------------------------------------===//
// ContinueOp
//===----------------------------------------------------------------------===//

LogicalResult cir::ContinueOp::verify() {
if (!getOperation()->getParentOfType<LoopOpInterface>())
return emitOpError("must be within a loop");
return success();
}

//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//
Expand Down
21 changes: 11 additions & 10 deletions clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,23 +148,24 @@ class CIRLoopOpInterfaceFlattening
// driver to customize the order that operations are visited.

// Lower continue statements.
mlir::Block *dest = (step ? step : cond);
op.walkBodySkippingNestedLoops([&](mlir::Operation *op) {
// When continue ops are supported, there will be a check for them here
// and a call to lowerTerminator(). The call to `advance()` handles the
// case where this is not a continue op.
assert(!cir::MissingFeatures::continueOp());
return mlir::WalkResult::advance();
if (!isa<cir::ContinueOp>(op))
return mlir::WalkResult::advance();

lowerTerminator(op, dest, rewriter);
return mlir::WalkResult::skip();
});

// Lower break statements.
assert(!cir::MissingFeatures::switchOp());
walkRegionSkipping<cir::LoopOpInterface>(
op.getBody(), [&](mlir::Operation *op) {
// When break ops are supported, there will be a check for them here
// and a call to lowerTerminator(). The call to `advance()` handles
// the case where this is not a break op.
assert(!cir::MissingFeatures::breakOp());
return mlir::WalkResult::advance();
if (!isa<cir::BreakOp>(op))
return mlir::WalkResult::advance();

lowerTerminator(op, exit, rewriter);
return mlir::WalkResult::skip();
});

// Lower optional body region yield.
Expand Down
122 changes: 122 additions & 0 deletions clang/test/CIR/CodeGen/loop.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -265,3 +265,125 @@ void test_empty_while_true() {
// OGCG: br label %[[WHILE_BODY:.*]]
// OGCG: [[WHILE_BODY]]:
// OGCG: ret void

void unreachable_after_continue() {
for (;;) {
continue;
int x = 1;
}
}

// CIR: cir.func @unreachable_after_continue
// CIR: cir.scope {
// CIR: cir.for : cond {
// CIR: %[[TRUE:.*]] = cir.const #true
// CIR: cir.condition(%[[TRUE]])
// CIR: } body {
// CIR: cir.scope {
// CIR: %[[X:.*]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["x", init] {alignment = 4 : i64}
// CIR: cir.continue
// CIR: ^bb1: // no predecessors
// CIR: %[[ONE:.*]] = cir.const #cir.int<1> : !s32i
// CIR: cir.store %[[ONE]], %[[X]] : !s32i, !cir.ptr<!s32i>
// CIR: cir.yield
// CIR: }
// CIR: cir.yield
// CIR: } step {
// CIR: cir.yield
// CIR: }
// CIR: }
// CIR: cir.return
// CIR: }

// LLVM: define void @unreachable_after_continue()
// LLVM: %[[X:.*]] = alloca i32, i64 1, align 4
// LLVM: br label %[[LABEL1:.*]]
// LLVM: [[LABEL1]]:
// LLVM: br label %[[LABEL2:.*]]
// LLVM: [[LABEL2]]:
// LLVM: br i1 true, label %[[LABEL3:.*]], label %[[LABEL8:.*]]
// LLVM: [[LABEL3]]:
// LLVM: br label %[[LABEL4:.*]]
// LLVM: [[LABEL4]]:
// LLVM: br label %[[LABEL7:.*]]
// LLVM: [[LABEL5:.*]]:
// LLVM-SAME: ; No predecessors!
// LLVM: store i32 1, ptr %[[X]], align 4
// LLVM: br label %[[LABEL6:.*]]
// LLVM: [[LABEL6]]:
// LLVM: br label %[[LABEL7:.*]]
// LLVM: [[LABEL7]]:
// LLVM: br label %[[LABEL2]]
// LLVM: [[LABEL8]]:
// LLVM: br label %[[LABEL9:]]
// LLVM: [[LABEL9]]:
// LLVM: ret void

// OGCG: define{{.*}} void @_Z26unreachable_after_continuev()
// OGCG: entry:
// OGCG: %[[X:.*]] = alloca i32, align 4
// OGCG: br label %[[FOR_COND:.*]]
// OGCG: [[FOR_COND]]:
// OGCG: br label %[[FOR_COND]]

void unreachable_after_break() {
for (;;) {
break;
int x = 1;
}
}

// CIR: cir.func @unreachable_after_break
// CIR: cir.scope {
// CIR: cir.for : cond {
// CIR: %[[TRUE:.*]] = cir.const #true
// CIR: cir.condition(%[[TRUE]])
// CIR: } body {
// CIR: cir.scope {
// CIR: %[[X:.*]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["x", init] {alignment = 4 : i64}
// CIR: cir.break
// CIR: ^bb1: // no predecessors
// CIR: %[[ONE:.*]] = cir.const #cir.int<1> : !s32i
// CIR: cir.store %[[ONE]], %[[X]] : !s32i, !cir.ptr<!s32i>
// CIR: cir.yield
// CIR: }
// CIR: cir.yield
// CIR: } step {
// CIR: cir.yield
// CIR: }
// CIR: }
// CIR: cir.return
// CIR: }

// LLVM: define void @unreachable_after_break()
// LLVM: %[[X:.*]] = alloca i32, i64 1, align 4
// LLVM: br label %[[LABEL1:.*]]
// LLVM: [[LABEL1]]:
// LLVM: br label %[[LABEL2:.*]]
// LLVM: [[LABEL2]]:
// LLVM: br i1 true, label %[[LABEL3:.*]], label %[[LABEL8:.*]]
// LLVM: [[LABEL3]]:
// LLVM: br label %[[LABEL4:.*]]
// LLVM: [[LABEL4]]:
// LLVM: br label %[[LABEL8]]
// LLVM: [[LABEL5:.*]]:
// LLVM-SAME: ; No predecessors!
// LLVM: store i32 1, ptr %[[X]], align 4
// LLVM: br label %[[LABEL6:.*]]
// LLVM: [[LABEL6]]:
// LLVM: br label %[[LABEL7:.*]]
// LLVM: [[LABEL7]]:
// LLVM: br label %[[LABEL2]]
// LLVM: [[LABEL8]]:
// LLVM: br label %[[LABEL9:]]
// LLVM: [[LABEL9]]:
// LLVM: ret void

// OGCG: define{{.*}} void @_Z23unreachable_after_breakv()
// OGCG: entry:
// OGCG: %[[X:.*]] = alloca i32, align 4
// OGCG: br label %[[FOR_COND:.*]]
// OGCG: [[FOR_COND]]:
// OGCG: br label %[[FOR_END:.*]]
// OGCG: [[FOR_END]]:
// OGCG: ret void