Skip to content

[CIR] Upstream support for while and do..while loops #133157

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,22 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
return cir::BoolAttr::get(getContext(), getBoolTy(), state);
}

/// Create a do-while operation.
cir::DoWhileOp createDoWhile(
mlir::Location loc,
llvm::function_ref<void(mlir::OpBuilder &, mlir::Location)> condBuilder,
llvm::function_ref<void(mlir::OpBuilder &, mlir::Location)> bodyBuilder) {
return create<cir::DoWhileOp>(loc, condBuilder, bodyBuilder);
}

/// Create a while operation.
cir::WhileOp createWhile(
mlir::Location loc,
llvm::function_ref<void(mlir::OpBuilder &, mlir::Location)> condBuilder,
llvm::function_ref<void(mlir::OpBuilder &, mlir::Location)> bodyBuilder) {
return create<cir::WhileOp>(loc, condBuilder, bodyBuilder);
}

/// Create a for operation.
cir::ForOp createFor(
mlir::Location loc,
Expand Down
3 changes: 3 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIRDialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@
#include "clang/CIR/Interfaces/CIRLoopOpInterface.h"
#include "clang/CIR/Interfaces/CIROpInterfaces.h"

using BuilderCallbackRef =
llvm::function_ref<void(mlir::OpBuilder &, mlir::Location)>;

// TableGen'erated files for MLIR dialects require that a macro be defined when
// they are included. GET_OP_CLASSES tells the file to define the classes for
// the operations of that dialect.
Expand Down
103 changes: 100 additions & 3 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,8 @@ def StoreOp : CIR_Op<"store", [
// ReturnOp
//===----------------------------------------------------------------------===//

def ReturnOp : CIR_Op<"return", [ParentOneOf<["FuncOp", "ScopeOp", "ForOp"]>,
def ReturnOp : CIR_Op<"return", [ParentOneOf<["FuncOp", "ScopeOp", "DoWhileOp",
"WhileOp", "ForOp"]>,
Terminator]> {
let summary = "Return from function";
let description = [{
Expand Down Expand Up @@ -511,7 +512,8 @@ def ConditionOp : CIR_Op<"condition", [
//===----------------------------------------------------------------------===//

def YieldOp : CIR_Op<"yield", [ReturnLike, Terminator,
ParentOneOf<["ScopeOp", "ForOp"]>]> {
ParentOneOf<["ScopeOp", "WhileOp", "ForOp",
"DoWhileOp"]>]> {
let summary = "Represents the default branching behaviour of a region";
let description = [{
The `cir.yield` operation terminates regions on different CIR operations,
Expand Down Expand Up @@ -759,11 +761,106 @@ def BrCondOp : CIR_Op<"brcond",
}];
}

//===----------------------------------------------------------------------===//
// Common loop op definitions
//===----------------------------------------------------------------------===//

class LoopOpBase<string mnemonic> : CIR_Op<mnemonic, [
LoopOpInterface,
NoRegionArguments,
]> {
let extraClassDefinition = [{
void $cppClass::getSuccessorRegions(
mlir::RegionBranchPoint point,
llvm::SmallVectorImpl<mlir::RegionSuccessor> &regions) {
LoopOpInterface::getLoopOpSuccessorRegions(*this, point, regions);
}
llvm::SmallVector<Region *> $cppClass::getLoopRegions() {
return {&getBody()};
}
}];
}

//===----------------------------------------------------------------------===//
// While & DoWhileOp
//===----------------------------------------------------------------------===//

class WhileOpBase<string mnemonic> : LoopOpBase<mnemonic> {
defvar isWhile = !eq(mnemonic, "while");
let summary = "C/C++ " # !if(isWhile, "while", "do-while") # " loop";
let builders = [
OpBuilder<(ins "BuilderCallbackRef":$condBuilder,
"BuilderCallbackRef":$bodyBuilder), [{
mlir::OpBuilder::InsertionGuard guard($_builder);
$_builder.createBlock($_state.addRegion());
}] # !if(isWhile, [{
condBuilder($_builder, $_state.location);
$_builder.createBlock($_state.addRegion());
bodyBuilder($_builder, $_state.location);
}], [{
bodyBuilder($_builder, $_state.location);
$_builder.createBlock($_state.addRegion());
condBuilder($_builder, $_state.location);
}])>
];
}

def WhileOp : WhileOpBase<"while"> {
let regions = (region SizedRegion<1>:$cond, MinSizedRegion<1>:$body);
let assemblyFormat = "$cond `do` $body attr-dict";

let description = [{
Represents a C/C++ while loop. It consists of two regions:

- `cond`: single block region with the loop's condition. Should be
terminated with a `cir.condition` operation.
- `body`: contains the loop body and an arbitrary number of blocks.

Example:

```mlir
cir.while {
cir.break
^bb2:
cir.yield
} do {
cir.condition %cond : cir.bool
}
```
}];
}

def DoWhileOp : WhileOpBase<"do"> {
let regions = (region MinSizedRegion<1>:$body, SizedRegion<1>:$cond);
let assemblyFormat = " $body `while` $cond attr-dict";

let extraClassDeclaration = [{
mlir::Region &getEntry() { return getBody(); }
}];

let description = [{
Represents a C/C++ do-while loop. Identical to `cir.while` but the
condition is evaluated after the body.

Example:

```mlir
cir.do {
cir.break
^bb2:
cir.yield
} while {
cir.condition %cond : cir.bool
}
```
}];
}

//===----------------------------------------------------------------------===//
// ForOp
//===----------------------------------------------------------------------===//

def ForOp : CIR_Op<"for", [LoopOpInterface, NoRegionArguments]> {
def ForOp : LoopOpBase<"for"> {
let summary = "C/C++ for loop counterpart";
let description = [{
Represents a C/C++ for loop. It consists of three regions:
Expand Down
4 changes: 4 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,8 @@ class CIRGenFunction : public CIRGenTypeCache {

LValue emitBinaryOperatorLValue(const BinaryOperator *e);

mlir::LogicalResult emitDoStmt(const clang::DoStmt &s);

/// Emit an expression as an initializer for an object (variable, field, etc.)
/// at the given location. The expression is not necessarily the normal
/// initializer for the object, and the address is not necessarily
Expand Down Expand Up @@ -497,6 +499,8 @@ class CIRGenFunction : public CIRGenTypeCache {
/// inside a function, including static vars etc.
void emitVarDecl(const clang::VarDecl &d);

mlir::LogicalResult emitWhileStmt(const clang::WhileStmt &s);

/// ----------------------
/// CIR build helpers
/// -----------------
Expand Down
113 changes: 111 additions & 2 deletions clang/lib/CIR/CodeGen/CIRGenStmt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@ mlir::LogicalResult CIRGenFunction::emitStmt(const Stmt *s,

case Stmt::ForStmtClass:
return emitForStmt(cast<ForStmt>(*s));
case Stmt::WhileStmtClass:
return emitWhileStmt(cast<WhileStmt>(*s));
case Stmt::DoStmtClass:
return emitDoStmt(cast<DoStmt>(*s));

case Stmt::OMPScopeDirectiveClass:
case Stmt::OMPErrorDirectiveClass:
Expand All @@ -97,8 +101,6 @@ mlir::LogicalResult CIRGenFunction::emitStmt(const Stmt *s,
case Stmt::SYCLKernelCallStmtClass:
case Stmt::IfStmtClass:
case Stmt::SwitchStmtClass:
case Stmt::WhileStmtClass:
case Stmt::DoStmtClass:
case Stmt::CoroutineBodyStmtClass:
case Stmt::CoreturnStmtClass:
case Stmt::CXXTryStmtClass:
Expand Down Expand Up @@ -387,3 +389,110 @@ mlir::LogicalResult CIRGenFunction::emitForStmt(const ForStmt &s) {
terminateBody(builder, forOp.getBody(), getLoc(s.getEndLoc()));
return mlir::success();
}

mlir::LogicalResult CIRGenFunction::emitDoStmt(const DoStmt &s) {
cir::DoWhileOp doWhileOp;

// TODO: pass in array of attributes.
auto doStmtBuilder = [&]() -> mlir::LogicalResult {
mlir::LogicalResult loopRes = mlir::success();
assert(!cir::MissingFeatures::loopInfoStack());
// From LLVM: if there are any cleanups between here and the loop-exit
// scope, create a block to stage a loop exit along.
// We probably already do the right thing because of ScopeOp, but make
// sure we handle all cases.
assert(!cir::MissingFeatures::requiresCleanups());

doWhileOp = builder.createDoWhile(
getLoc(s.getSourceRange()),
/*condBuilder=*/
[&](mlir::OpBuilder &b, mlir::Location loc) {
assert(!cir::MissingFeatures::createProfileWeightsForLoop());
assert(!cir::MissingFeatures::emitCondLikelihoodViaExpectIntrinsic());
// C99 6.8.5p2/p4: The first substatement is executed if the
// expression compares unequal to 0. The condition must be a
// scalar type.
mlir::Value condVal = evaluateExprAsBool(s.getCond());
builder.createCondition(condVal);
},
/*bodyBuilder=*/
[&](mlir::OpBuilder &b, mlir::Location loc) {
// The scope of the do-while loop body is a nested scope.
if (emitStmt(s.getBody(), /*useCurrentScope=*/false).failed())
loopRes = mlir::failure();
emitStopPoint(&s);
});
return loopRes;
};

mlir::LogicalResult res = mlir::success();
mlir::Location scopeLoc = getLoc(s.getSourceRange());
builder.create<cir::ScopeOp>(scopeLoc, /*scopeBuilder=*/
[&](mlir::OpBuilder &b, mlir::Location loc) {
LexicalScope lexScope{
*this, loc, builder.getInsertionBlock()};
res = doStmtBuilder();
});

if (res.failed())
return res;

terminateBody(builder, doWhileOp.getBody(), getLoc(s.getEndLoc()));
return mlir::success();
}

mlir::LogicalResult CIRGenFunction::emitWhileStmt(const WhileStmt &s) {
cir::WhileOp whileOp;

// TODO: pass in array of attributes.
auto whileStmtBuilder = [&]() -> mlir::LogicalResult {
mlir::LogicalResult loopRes = mlir::success();
assert(!cir::MissingFeatures::loopInfoStack());
// From LLVM: if there are any cleanups between here and the loop-exit
// scope, create a block to stage a loop exit along.
// We probably already do the right thing because of ScopeOp, but make
// sure we handle all cases.
assert(!cir::MissingFeatures::requiresCleanups());

whileOp = builder.createWhile(
getLoc(s.getSourceRange()),
/*condBuilder=*/
[&](mlir::OpBuilder &b, mlir::Location loc) {
assert(!cir::MissingFeatures::createProfileWeightsForLoop());
assert(!cir::MissingFeatures::emitCondLikelihoodViaExpectIntrinsic());
mlir::Value condVal;
// If the for statement has a condition scope,
// emit the local variable declaration.
if (s.getConditionVariable())
emitDecl(*s.getConditionVariable());
// C99 6.8.5p2/p4: The first substatement is executed if the
// expression compares unequal to 0. The condition must be a
// scalar type.
condVal = evaluateExprAsBool(s.getCond());
builder.createCondition(condVal);
},
/*bodyBuilder=*/
[&](mlir::OpBuilder &b, mlir::Location loc) {
// The scope of the while loop body is a nested scope.
if (emitStmt(s.getBody(), /*useCurrentScope=*/false).failed())
loopRes = mlir::failure();
emitStopPoint(&s);
});
return loopRes;
};

mlir::LogicalResult res = mlir::success();
mlir::Location scopeLoc = getLoc(s.getSourceRange());
builder.create<cir::ScopeOp>(scopeLoc, /*scopeBuilder=*/
[&](mlir::OpBuilder &b, mlir::Location loc) {
LexicalScope lexScope{
*this, loc, builder.getInsertionBlock()};
res = whileStmtBuilder();
});

if (res.failed())
return res;

terminateBody(builder, whileOp.getBody(), getLoc(s.getEndLoc()));
return mlir::success();
}
14 changes: 0 additions & 14 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -538,20 +538,6 @@ Block *cir::BrCondOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
return nullptr;
}

//===----------------------------------------------------------------------===//
// ForOp
//===----------------------------------------------------------------------===//

void cir::ForOp::getSuccessorRegions(
mlir::RegionBranchPoint point,
llvm::SmallVectorImpl<mlir::RegionSuccessor> &regions) {
LoopOpInterface::getLoopOpSuccessorRegions(*this, point, regions);
}

llvm::SmallVector<Region *> cir::ForOp::getLoopRegions() {
return {&getBody()};
}

//===----------------------------------------------------------------------===//
// GlobalOp
//===----------------------------------------------------------------------===//
Expand Down
Loading
Loading