Skip to content

[CIR] Upstream TernaryOp #137184

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 60 additions & 3 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -610,9 +610,9 @@ def ConditionOp : CIR_Op<"condition", [
//===----------------------------------------------------------------------===//

def YieldOp : CIR_Op<"yield", [ReturnLike, Terminator,
ParentOneOf<["IfOp", "ScopeOp", "SwitchOp",
"WhileOp", "ForOp", "CaseOp",
"DoWhileOp"]>]> {
ParentOneOf<["CaseOp", "DoWhileOp", "ForOp",
"IfOp", "ScopeOp", "SwitchOp",
"TernaryOp", "WhileOp"]>]> {
let summary = "Represents the default branching behaviour of a region";
let description = [{
The `cir.yield` operation terminates regions on different CIR operations,
Expand Down Expand Up @@ -1462,6 +1462,63 @@ def SelectOp : CIR_Op<"select", [Pure,
}];
}

//===----------------------------------------------------------------------===//
// TernaryOp
//===----------------------------------------------------------------------===//

def TernaryOp : CIR_Op<"ternary",
[DeclareOpInterfaceMethods<RegionBranchOpInterface>,
RecursivelySpeculatable, AutomaticAllocationScope, NoRegionArguments]> {
let summary = "The `cond ? a : b` C/C++ ternary operation";
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is wrong. At least in the incubator, we use cir.select to implement the ternary operator.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We simplify cir.ternary to cir.select in FlattenCFG. See the tests here. Without -O1 the first test case gets represented by cir.ternary: https://godbolt.org/z/fz39rh5Kn

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh. I didn't realize we ran CIR simplification based on the opt level. I guess it makes sense in a way, but I'm not thrilled with the idea that this is happening in the front end. This relates to the suggestion I made elsewhere that maybe we should be moving the lowering and transforms out of the clang component.

let description = [{
The `cir.ternary` operation represents C/C++ ternary, much like a `select`
operation. The first argument is a `cir.bool` condition to evaluate, followed
by two regions to execute (true or false). This is different from `cir.if`
since each region is one block sized and the `cir.yield` closing the block
scope should have one argument.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we represent the GNU extension of x = Thing1 ?: Thing2; any differently?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

GNU ?: is only partially implemented (we lack glvalue support for opaque value expressions) but it's represented by cir.ternary.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we document it here? One of the 'features' of it IIRC is that it only evaluates Thing1 1x instead of 2x (like in a traditional ternary), so it might be nice to document how that is represented here.


`cir.ternary` also represents the GNU binary conditional operator ?: which
reuses the parent operation for both the condition and the true branch to
evaluate it only once.

Example:

```mlir
// cond = a && b;

%x = cir.ternary (%cond, true_region {
...
cir.yield %a : i32
}, false_region {
...
cir.yield %b : i32
}) -> i32
```
}];
let arguments = (ins CIR_BoolType:$cond);
let regions = (region AnyRegion:$trueRegion,
AnyRegion:$falseRegion);
let results = (outs Optional<CIR_AnyType>:$result);

let skipDefaultBuilders = 1;
let builders = [
OpBuilder<(ins "mlir::Value":$cond,
"llvm::function_ref<void(mlir::OpBuilder &, mlir::Location)>":$trueBuilder,
"llvm::function_ref<void(mlir::OpBuilder &, mlir::Location)>":$falseBuilder)
>
];

// All constraints already verified elsewhere.
let hasVerifier = 0;

let assemblyFormat = [{
`(` $cond `,`
`true` $trueRegion `,`
`false` $falseRegion
`)` `:` functional-type(operands, results) attr-dict
}];
}

//===----------------------------------------------------------------------===//
// GlobalOp
//===----------------------------------------------------------------------===//
Expand Down
43 changes: 43 additions & 0 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1187,6 +1187,49 @@ LogicalResult cir::BinOp::verify() {
return mlir::success();
}

//===----------------------------------------------------------------------===//
// TernaryOp
//===----------------------------------------------------------------------===//

/// Given the region at `point`, or the parent operation if `point` is None,
/// return the successor regions. These are the regions that may be selected
/// during the flow of control. `operands` is a set of optional attributes that
/// correspond to a constant value for each operand, or null if that operand is
/// not a constant.
void cir::TernaryOp::getSuccessorRegions(
mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
// The `true` and the `false` region branch back to the parent operation.
if (!point.isParent()) {
regions.push_back(RegionSuccessor(this->getODSResults(0)));
return;
}

// When branching from the parent operation, both the true and false
// regions are considered possible successors
regions.push_back(RegionSuccessor(&getTrueRegion()));
regions.push_back(RegionSuccessor(&getFalseRegion()));
}

void cir::TernaryOp::build(
OpBuilder &builder, OperationState &result, Value cond,
function_ref<void(OpBuilder &, Location)> trueBuilder,
function_ref<void(OpBuilder &, Location)> falseBuilder) {
result.addOperands(cond);
OpBuilder::InsertionGuard guard(builder);
Region *trueRegion = result.addRegion();
Block *block = builder.createBlock(trueRegion);
trueBuilder(builder, result.location);
Region *falseRegion = result.addRegion();
builder.createBlock(falseRegion);
falseBuilder(builder, result.location);

auto yield = dyn_cast<YieldOp>(block->getTerminator());
assert((yield && yield.getNumOperands() <= 1) &&
"expected zero or one result type");
if (yield.getNumOperands() == 1)
result.addTypes(TypeRange{yield.getOperandTypes().front()});
}

//===----------------------------------------------------------------------===//
// ShiftOp
//===----------------------------------------------------------------------===//
Expand Down
60 changes: 55 additions & 5 deletions clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -254,10 +254,61 @@ class CIRLoopOpInterfaceFlattening
}
};

class CIRTernaryOpFlattening : public mlir::OpRewritePattern<cir::TernaryOp> {
public:
using OpRewritePattern<cir::TernaryOp>::OpRewritePattern;

mlir::LogicalResult
matchAndRewrite(cir::TernaryOp op,
mlir::PatternRewriter &rewriter) const override {
Location loc = op->getLoc();
Block *condBlock = rewriter.getInsertionBlock();
Block::iterator opPosition = rewriter.getInsertionPoint();
Block *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition);
llvm::SmallVector<mlir::Location, 2> locs;
// Ternary result is optional, make sure to populate the location only
// when relevant.
if (op->getResultTypes().size())
locs.push_back(loc);
Block *continueBlock =
rewriter.createBlock(remainingOpsBlock, op->getResultTypes(), locs);
rewriter.create<cir::BrOp>(loc, remainingOpsBlock);

Region &trueRegion = op.getTrueRegion();
Block *trueBlock = &trueRegion.front();
mlir::Operation *trueTerminator = trueRegion.back().getTerminator();
rewriter.setInsertionPointToEnd(&trueRegion.back());
auto trueYieldOp = dyn_cast<cir::YieldOp>(trueTerminator);

rewriter.replaceOpWithNewOp<cir::BrOp>(trueYieldOp, trueYieldOp.getArgs(),
continueBlock);
rewriter.inlineRegionBefore(trueRegion, continueBlock);

Block *falseBlock = continueBlock;
Region &falseRegion = op.getFalseRegion();

falseBlock = &falseRegion.front();
mlir::Operation *falseTerminator = falseRegion.back().getTerminator();
rewriter.setInsertionPointToEnd(&falseRegion.back());
auto falseYieldOp = dyn_cast<cir::YieldOp>(falseTerminator);
rewriter.replaceOpWithNewOp<cir::BrOp>(falseYieldOp, falseYieldOp.getArgs(),
continueBlock);
rewriter.inlineRegionBefore(falseRegion, continueBlock);

rewriter.setInsertionPointToEnd(condBlock);
rewriter.create<cir::BrCondOp>(loc, op.getCond(), trueBlock, falseBlock);

rewriter.replaceOp(op, continueBlock->getArguments());

// Ok, we're done!
return mlir::success();
}
};

void populateFlattenCFGPatterns(RewritePatternSet &patterns) {
patterns
.add<CIRIfFlattening, CIRLoopOpInterfaceFlattening, CIRScopeOpFlattening>(
patterns.getContext());
patterns.add<CIRIfFlattening, CIRLoopOpInterfaceFlattening,
CIRScopeOpFlattening, CIRTernaryOpFlattening>(
patterns.getContext());
}

void CIRFlattenCFGPass::runOnOperation() {
Expand All @@ -269,9 +320,8 @@ void CIRFlattenCFGPass::runOnOperation() {
getOperation()->walk<mlir::WalkOrder::PostOrder>([&](Operation *op) {
assert(!cir::MissingFeatures::ifOp());
assert(!cir::MissingFeatures::switchOp());
assert(!cir::MissingFeatures::ternaryOp());
assert(!cir::MissingFeatures::tryOp());
if (isa<IfOp, ScopeOp, LoopOpInterface>(op))
if (isa<IfOp, ScopeOp, LoopOpInterface, TernaryOp>(op))
ops.push_back(op);
});

Expand Down
30 changes: 30 additions & 0 deletions clang/test/CIR/IR/ternary.cir
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// RUN: cir-opt %s | cir-opt | FileCheck %s
!u32i = !cir.int<u, 32>

module {
cir.func @blue(%arg0: !cir.bool) -> !u32i {
%0 = cir.ternary(%arg0, true {
%a = cir.const #cir.int<0> : !u32i
cir.yield %a : !u32i
}, false {
%b = cir.const #cir.int<1> : !u32i
cir.yield %b : !u32i
}) : (!cir.bool) -> !u32i
cir.return %0 : !u32i
}
}

// CHECK: module {

// CHECK: cir.func @blue(%arg0: !cir.bool) -> !u32i {
// CHECK: %0 = cir.ternary(%arg0, true {
// CHECK: %1 = cir.const #cir.int<0> : !u32i
// CHECK: cir.yield %1 : !u32i
// CHECK: }, false {
// CHECK: %1 = cir.const #cir.int<1> : !u32i
// CHECK: cir.yield %1 : !u32i
// CHECK: }) : (!cir.bool) -> !u32i
// CHECK: cir.return %0 : !u32i
// CHECK: }

// CHECK: }
30 changes: 30 additions & 0 deletions clang/test/CIR/Lowering/ternary.cir
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// RUN: cir-translate -cir-to-llvmir --disable-cc-lowering -o %t.ll %s
// RUN: FileCheck --input-file=%t.ll -check-prefix=LLVM %s

!u32i = !cir.int<u, 32>

module {
cir.func @blue(%arg0: !cir.bool) -> !u32i {
%0 = cir.ternary(%arg0, true {
%a = cir.const #cir.int<0> : !u32i
cir.yield %a : !u32i
}, false {
%b = cir.const #cir.int<1> : !u32i
cir.yield %b : !u32i
}) : (!cir.bool) -> !u32i
cir.return %0 : !u32i
}
}

// LLVM-LABEL: define i32 {{.*}}@blue(
// LLVM-SAME: i1 [[PRED:%[[:alnum:]]+]])
// LLVM: br i1 [[PRED]], label %[[B1:[[:alnum:]]+]], label %[[B2:[[:alnum:]]+]]
// LLVM: [[B1]]:
// LLVM: br label %[[M:[[:alnum:]]+]]
// LLVM: [[B2]]:
// LLVM: br label %[[M]]
// LLVM: [[M]]:
// LLVM: [[R:%[[:alnum:]]+]] = phi i32 [ 1, %[[B2]] ], [ 0, %[[B1]] ]
// LLVM: br label %[[B3:[[:alnum:]]+]]
// LLVM: [[B3]]:
// LLVM: ret i32 [[R]]
68 changes: 68 additions & 0 deletions clang/test/CIR/Transforms/ternary.cir
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// RUN: cir-opt %s -cir-flatten-cfg -o - | FileCheck %s

!s32i = !cir.int<s, 32>

module {
cir.func @foo(%arg0: !s32i) -> !s32i {
%0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["y", init] {alignment = 4 : i64}
%1 = cir.alloca !s32i, !cir.ptr<!s32i>, ["__retval"] {alignment = 4 : i64}
cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i>
%2 = cir.load %0 : !cir.ptr<!s32i>, !s32i
%3 = cir.const #cir.int<0> : !s32i
%4 = cir.cmp(gt, %2, %3) : !s32i, !cir.bool
%5 = cir.ternary(%4, true {
%7 = cir.const #cir.int<3> : !s32i
cir.yield %7 : !s32i
}, false {
%7 = cir.const #cir.int<5> : !s32i
cir.yield %7 : !s32i
}) : (!cir.bool) -> !s32i
cir.store %5, %1 : !s32i, !cir.ptr<!s32i>
%6 = cir.load %1 : !cir.ptr<!s32i>, !s32i
cir.return %6 : !s32i
}

// CHECK: cir.func @foo(%arg0: !s32i) -> !s32i {
// CHECK: %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["y", init] {alignment = 4 : i64}
// CHECK: %1 = cir.alloca !s32i, !cir.ptr<!s32i>, ["__retval"] {alignment = 4 : i64}
// CHECK: cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i>
// CHECK: %2 = cir.load %0 : !cir.ptr<!s32i>, !s32i
// CHECK: %3 = cir.const #cir.int<0> : !s32i
// CHECK: %4 = cir.cmp(gt, %2, %3) : !s32i, !cir.bool
// CHECK: cir.brcond %4 ^bb1, ^bb2
// CHECK: ^bb1: // pred: ^bb0
// CHECK: %5 = cir.const #cir.int<3> : !s32i
// CHECK: cir.br ^bb3(%5 : !s32i)
// CHECK: ^bb2: // pred: ^bb0
// CHECK: %6 = cir.const #cir.int<5> : !s32i
// CHECK: cir.br ^bb3(%6 : !s32i)
// CHECK: ^bb3(%7: !s32i): // 2 preds: ^bb1, ^bb2
// CHECK: cir.br ^bb4
// CHECK: ^bb4: // pred: ^bb3
// CHECK: cir.store %7, %1 : !s32i, !cir.ptr<!s32i>
// CHECK: %8 = cir.load %1 : !cir.ptr<!s32i>, !s32i
// CHECK: cir.return %8 : !s32i
// CHECK: }

cir.func @foo2(%arg0: !cir.bool) {
cir.ternary(%arg0, true {
cir.yield
}, false {
cir.yield
}) : (!cir.bool) -> ()
cir.return
}

// CHECK: cir.func @foo2(%arg0: !cir.bool) {
// CHECK: cir.brcond %arg0 ^bb1, ^bb2
// CHECK: ^bb1: // pred: ^bb0
// CHECK: cir.br ^bb3
// CHECK: ^bb2: // pred: ^bb0
// CHECK: cir.br ^bb3
// CHECK: ^bb3: // 2 preds: ^bb1, ^bb2
// CHECK: cir.br ^bb4
// CHECK: ^bb4: // pred: ^bb3
// CHECK: cir.return
// CHECK: }

}