Skip to content

[CIR] Implement switch case simplify #140649

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion clang/include/clang/CIR/MissingFeatures.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,6 @@ struct MissingFeatures {
static bool opUnaryPromotionType() { return false; }

// SwitchOp handling
static bool foldCascadingCases() { return false; }
static bool foldRangeCase() { return false; }

// Clang early optimizations or things defered to LLVM lowering.
Expand Down
6 changes: 0 additions & 6 deletions clang/lib/CIR/CodeGen/CIRGenStmt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -531,12 +531,6 @@ mlir::LogicalResult CIRGenFunction::emitCaseStmt(const CaseStmt &s,
value = builder.getArrayAttr({cir::IntAttr::get(condType, intVal),
cir::IntAttr::get(condType, endVal)});
kind = cir::CaseOpKind::Range;

// We don't currently fold case range statements with other case statements.
// TODO(cir): Add this capability. Folding these cases is going to be
// implemented in CIRSimplify when it is upstreamed.
assert(!cir::MissingFeatures::foldRangeCase());
assert(!cir::MissingFeatures::foldCascadingCases());
} else {
value = builder.getArrayAttr({cir::IntAttr::get(condType, intVal)});
kind = cir::CaseOpKind::Equal;
Expand Down
106 changes: 104 additions & 2 deletions clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,107 @@ struct SimplifySelect : public OpRewritePattern<SelectOp> {
}
};

/// Simplify `cir.switch` operations by folding cascading cases
/// into a single `cir.case` with the `anyof` kind.
///
/// This pattern identifies cascading cases within a `cir.switch` operation.
/// Cascading cases are defined as consecutive `cir.case` operations of kind
/// `equal`, each containing a single `cir.yield` operation in their body.
///
/// The pattern merges these cascading cases into a single `cir.case` operation
/// with kind `anyof`, aggregating all the case values.
///
/// The merging process continues until a `cir.case` with a different body
/// (e.g., containing `cir.break` or compound stmt) is encountered, which
/// breaks the chain.
///
/// Example:
///
/// Before:
/// cir.case equal, [#cir.int<0> : !s32i] {
/// cir.yield
/// }
/// cir.case equal, [#cir.int<1> : !s32i] {
/// cir.yield
/// }
/// cir.case equal, [#cir.int<2> : !s32i] {
/// cir.break
/// }
///
/// After applying SimplifySwitch:
/// cir.case anyof, [#cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<2> :
/// !s32i] {
/// cir.break
/// }
struct SimplifySwitch : public OpRewritePattern<SwitchOp> {
using OpRewritePattern<SwitchOp>::OpRewritePattern;
LogicalResult matchAndRewrite(SwitchOp op,
PatternRewriter &rewriter) const override {

LogicalResult changed = mlir::failure();
SmallVector<CaseOp, 8> cases;
SmallVector<CaseOp, 4> cascadingCases;
SmallVector<mlir::Attribute, 4> cascadingCaseValues;

op.collectCases(cases);
if (cases.empty())
return mlir::failure();

auto flushMergedOps = [&]() {
for (CaseOp &c : cascadingCases)
rewriter.eraseOp(c);
cascadingCases.clear();
cascadingCaseValues.clear();
};

auto mergeCascadingInto = [&](CaseOp &target) {
rewriter.modifyOpInPlace(target, [&]() {
target.setValueAttr(rewriter.getArrayAttr(cascadingCaseValues));
target.setKind(CaseOpKind::Anyof);
});
changed = mlir::success();
};

for (CaseOp c : cases) {
cir::CaseOpKind kind = c.getKind();
if (kind == cir::CaseOpKind::Equal &&
isa<YieldOp>(c.getCaseRegion().front().front())) {
// If the case contains only a YieldOp, collect it for cascading merge
cascadingCases.push_back(c);
cascadingCaseValues.push_back(c.getValue()[0]);
} else if (kind == cir::CaseOpKind::Equal && !cascadingCases.empty()) {
// merge previously collected cascading cases
cascadingCaseValues.push_back(c.getValue()[0]);
mergeCascadingInto(c);
flushMergedOps();
} else if (kind != cir::CaseOpKind::Equal && cascadingCases.size() > 1) {
// If a Default, Anyof or Range case is found and there are previous
// cascading cases, merge all of them into the last cascading case.
// We don't currently fold case range statements with other case
// statements.
assert(!cir::MissingFeatures::foldRangeCase());
CaseOp lastCascadingCase = cascadingCases.back();
mergeCascadingInto(lastCascadingCase);
cascadingCases.pop_back();
flushMergedOps();
} else {
cascadingCases.clear();
cascadingCaseValues.clear();
}
}

// Edge case: all cases are simple cascading cases
if (cascadingCases.size() == cases.size()) {
CaseOp lastCascadingCase = cascadingCases.back();
mergeCascadingInto(lastCascadingCase);
cascadingCases.pop_back();
flushMergedOps();
}

return changed;
}
};

//===----------------------------------------------------------------------===//
// CIRSimplifyPass
//===----------------------------------------------------------------------===//
Expand All @@ -173,7 +274,8 @@ void populateMergeCleanupPatterns(RewritePatternSet &patterns) {
// clang-format off
patterns.add<
SimplifyTernary,
SimplifySelect
SimplifySelect,
SimplifySwitch
>(patterns.getContext());
// clang-format on
}
Expand All @@ -186,7 +288,7 @@ void CIRSimplifyPass::runOnOperation() {
// Collect operations to apply patterns.
llvm::SmallVector<Operation *, 16> ops;
getOperation()->walk([&](Operation *op) {
if (isa<TernaryOp, SelectOp>(op))
if (isa<TernaryOp, SelectOp, SwitchOp>(op))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you double check if the test passes without issues if -DMLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS=ON is used while building clang? We currently having incubator issues with this and probably best to make sure we don't introduce them here if possible.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have built Clang with the -DMLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS=ON flag, and this pull request passes successfully. However, I have identified two failing tests:


********************
Failed Tests (2):
  Clang :: CIR/CodeGen/loop.cpp
  Clang :: CIR/Transforms/switch.cir
********************

The CIR/Transforms/switch.cir test fails when applying the -cir-flatten-cfg pass. I'm going to check if this test is also failing in the incubator.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just checked and it also fails in the incubator.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for walking the extra leg here, can you create an issue for this so we can take a look later? (cc @xlauko which pointed this recently).

ops.push_back(op);
});

Expand Down
196 changes: 196 additions & 0 deletions clang/test/CIR/Transforms/switch-fold.cir
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
// RUN: cir-opt -cir-canonicalize -cir-simplify -o %t.cir %s
// RUN: FileCheck --input-file=%t.cir %s

!s32i = !cir.int<s, 32>

module {
cir.func @foldCascade(%arg0: !s32i) {
%0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init] {alignment = 4 : i64}
cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i>
cir.scope {
%1 = cir.load %0 : !cir.ptr<!s32i>, !s32i
cir.switch (%1 : !s32i) {
cir.case(equal, [#cir.int<1> : !s32i]) {
cir.yield
}
cir.case(equal, [#cir.int<2> : !s32i]) {
cir.yield
}
cir.case(equal, [#cir.int<3> : !s32i]) {
%2 = cir.const #cir.int<2> : !s32i
cir.store %2, %0 : !s32i, !cir.ptr<!s32i>
cir.break
}
cir.yield
}
}
cir.return
}
//CHECK: cir.func @foldCascade
//CHECK: cir.switch (%[[COND:.*]] : !s32i) {
//CHECK-NEXT: cir.case(anyof, [#cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.int<3> : !s32i]) {
//CHECK-NEXT: %[[TWO:.*]] = cir.const #cir.int<2> : !s32i
//CHECK-NEXT: cir.store %[[TWO]], %[[ARG0:.*]] : !s32i, !cir.ptr<!s32i>
//CHECK-NEXT: cir.break
//CHECK-NEXT: }
//CHECK-NEXT: cir.yield
//CHECK-NEXT: }

cir.func @foldCascade2(%arg0: !s32i) {
%0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init] {alignment = 4 : i64}
cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i>
cir.scope {
%1 = cir.load %0 : !cir.ptr<!s32i>, !s32i
cir.switch (%1 : !s32i) {
cir.case(equal, [#cir.int<0> : !s32i]) {
cir.yield
}
cir.case(equal, [#cir.int<2> : !s32i]) {
cir.yield
}
cir.case(equal, [#cir.int<4> : !s32i]) {
cir.break
}
cir.case(equal, [#cir.int<1> : !s32i]) {
cir.yield
}
cir.case(equal, [#cir.int<3> : !s32i]) {
cir.yield
}
cir.case(equal, [#cir.int<5> : !s32i]) {
cir.break
}
cir.yield
}
}
cir.return
}
//CHECK: @foldCascade2
//CHECK: cir.switch (%[[COND2:.*]] : !s32i) {
//CHECK: cir.case(anyof, [#cir.int<0> : !s32i, #cir.int<2> : !s32i, #cir.int<4> : !s32i]) {
//CHECK: cir.break
//cehck: }
//CHECK: cir.case(anyof, [#cir.int<1> : !s32i, #cir.int<3> : !s32i, #cir.int<5> : !s32i]) {
//CHECK: cir.break
//CHECK: }
//CHECK: cir.yield
//CHECK: }
cir.func @foldCascade3(%arg0: !s32i ) {
%0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init] {alignment = 4 : i64}
cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i>
cir.scope {
%1 = cir.alloca !s32i, !cir.ptr<!s32i>, ["x"] {alignment = 4 : i64}
%2 = cir.load %0 : !cir.ptr<!s32i>, !s32i
cir.switch (%2 : !s32i) {
cir.case(equal, [#cir.int<0> : !s32i]) {
cir.yield
}
cir.case(equal, [#cir.int<1> : !s32i]) {
cir.yield
}
cir.case(equal, [#cir.int<2> : !s32i]) {
cir.yield
}
cir.case(equal, [#cir.int<3> : !s32i]) {
cir.yield
}
cir.case(equal, [#cir.int<4> : !s32i]) {
cir.yield
}
cir.case(equal, [#cir.int<5> : !s32i]) {
cir.break
}
cir.yield
}
}
cir.return
}
//CHECK: cir.func @foldCascade3
//CHECK: cir.switch (%[[COND3:.*]] : !s32i) {
//CHECK: cir.case(anyof, [#cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.int<3> : !s32i, #cir.int<4> : !s32i, #cir.int<5> : !s32i]) {
//CHECK: cir.break
//CHECK: }
//CHECK: cir.yield
//CHECK: }
cir.func @foldCascadeWithDefault(%arg0: !s32i ) {
%0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init] {alignment = 4 : i64}
cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i>
cir.scope {
%1 = cir.load %0 : !cir.ptr<!s32i>, !s32i
cir.switch (%1 : !s32i) {
cir.case(equal, [#cir.int<3> : !s32i]) {
cir.break
}
cir.case(equal, [#cir.int<4> : !s32i]) {
cir.yield
}
cir.case(equal, [#cir.int<5> : !s32i]) {
cir.yield
}
cir.case(default, []) {
cir.yield
}
cir.case(equal, [#cir.int<6> : !s32i]) {
cir.yield
}
cir.case(equal, [#cir.int<7> : !s32i]) {
cir.break
}
cir.yield
}
}
cir.return
}
//CHECK: cir.func @foldCascadeWithDefault
//CHECK: cir.switch (%[[COND:.*]] : !s32i) {
//CHECK: cir.case(equal, [#cir.int<3> : !s32i]) {
//CHECK: cir.break
//CHECK: }
//CHECK: cir.case(anyof, [#cir.int<4> : !s32i, #cir.int<5> : !s32i]) {
//CHECK: cir.yield
//CHECK: }
//CHECK: cir.case(default, []) {
//CHECK: cir.yield
//CHECK: }
//CHECK: cir.case(anyof, [#cir.int<6> : !s32i, #cir.int<7> : !s32i]) {
//CHECK: cir.break
//CHECK: }
//CHECK: cir.yield
//CHECK: }
cir.func @foldAllCascade(%arg0: !s32i ) {
%0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init] {alignment = 4 : i64}
cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i>
cir.scope {
%1 = cir.load %0 : !cir.ptr<!s32i>, !s32i
cir.switch (%1 : !s32i) {
cir.case(equal, [#cir.int<0> : !s32i]) {
cir.yield
}
cir.case(equal, [#cir.int<1> : !s32i]) {
cir.yield
}
cir.case(equal, [#cir.int<2> : !s32i]) {
cir.yield
}
cir.case(equal, [#cir.int<3> : !s32i]) {
cir.yield
}
cir.case(equal, [#cir.int<4> : !s32i]) {
cir.yield
}
cir.case(equal, [#cir.int<5> : !s32i]) {
cir.yield
}
cir.yield
}
}
cir.return
}
//CHECK: cir.func @foldAllCascade
//CHECK: cir.switch (%[[COND:.*]] : !s32i) {
//CHECK: cir.case(anyof, [#cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.int<3> : !s32i, #cir.int<4> : !s32i, #cir.int<5> : !s32i]) {
//CHECK: cir.yield
//CHECK: }
//CHECK: cir.yield
//CHECK: }
}