-
Notifications
You must be signed in to change notification settings - Fork 14.4k
[CIR] Implement switch case simplify #140649
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -159,6 +159,107 @@ struct SimplifySelect : public OpRewritePattern<SelectOp> { | |
} | ||
}; | ||
|
||
/// Simplify `cir.switch` operations by folding cascading cases | ||
/// into a single `cir.case` with the `anyof` kind. | ||
/// | ||
/// This pattern identifies cascading cases within a `cir.switch` operation. | ||
/// Cascading cases are defined as consecutive `cir.case` operations of kind | ||
/// `equal`, each containing a single `cir.yield` operation in their body. | ||
/// | ||
/// The pattern merges these cascading cases into a single `cir.case` operation | ||
/// with kind `anyof`, aggregating all the case values. | ||
/// | ||
/// The merging process continues until a `cir.case` with a different body | ||
/// (e.g., containing `cir.break` or compound stmt) is encountered, which | ||
/// breaks the chain. | ||
/// | ||
/// Example: | ||
/// | ||
/// Before: | ||
/// cir.case equal, [#cir.int<0> : !s32i] { | ||
/// cir.yield | ||
/// } | ||
/// cir.case equal, [#cir.int<1> : !s32i] { | ||
/// cir.yield | ||
/// } | ||
/// cir.case equal, [#cir.int<2> : !s32i] { | ||
/// cir.break | ||
/// } | ||
/// | ||
/// After applying SimplifySwitch: | ||
/// cir.case anyof, [#cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<2> : | ||
/// !s32i] { | ||
/// cir.break | ||
/// } | ||
struct SimplifySwitch : public OpRewritePattern<SwitchOp> { | ||
using OpRewritePattern<SwitchOp>::OpRewritePattern; | ||
LogicalResult matchAndRewrite(SwitchOp op, | ||
PatternRewriter &rewriter) const override { | ||
|
||
LogicalResult changed = mlir::failure(); | ||
SmallVector<CaseOp, 8> cases; | ||
SmallVector<CaseOp, 4> cascadingCases; | ||
SmallVector<mlir::Attribute, 4> cascadingCaseValues; | ||
|
||
op.collectCases(cases); | ||
if (cases.empty()) | ||
return mlir::failure(); | ||
|
||
auto flushMergedOps = [&]() { | ||
for (CaseOp &c : cascadingCases) | ||
rewriter.eraseOp(c); | ||
cascadingCases.clear(); | ||
cascadingCaseValues.clear(); | ||
}; | ||
|
||
auto mergeCascadingInto = [&](CaseOp &target) { | ||
rewriter.modifyOpInPlace(target, [&]() { | ||
target.setValueAttr(rewriter.getArrayAttr(cascadingCaseValues)); | ||
target.setKind(CaseOpKind::Anyof); | ||
}); | ||
changed = mlir::success(); | ||
}; | ||
|
||
for (CaseOp c : cases) { | ||
cir::CaseOpKind kind = c.getKind(); | ||
if (kind == cir::CaseOpKind::Equal && | ||
isa<YieldOp>(c.getCaseRegion().front().front())) { | ||
// If the case contains only a YieldOp, collect it for cascading merge | ||
cascadingCases.push_back(c); | ||
cascadingCaseValues.push_back(c.getValue()[0]); | ||
} else if (kind == cir::CaseOpKind::Equal && !cascadingCases.empty()) { | ||
// merge previously collected cascading cases | ||
cascadingCaseValues.push_back(c.getValue()[0]); | ||
mergeCascadingInto(c); | ||
flushMergedOps(); | ||
} else if (kind != cir::CaseOpKind::Equal && cascadingCases.size() > 1) { | ||
// If a Default, Anyof or Range case is found and there are previous | ||
// cascading cases, merge all of them into the last cascading case. | ||
// We don't currently fold case range statements with other case | ||
// statements. | ||
assert(!cir::MissingFeatures::foldRangeCase()); | ||
CaseOp lastCascadingCase = cascadingCases.back(); | ||
mergeCascadingInto(lastCascadingCase); | ||
cascadingCases.pop_back(); | ||
flushMergedOps(); | ||
} else { | ||
cascadingCases.clear(); | ||
andykaylor marked this conversation as resolved.
Show resolved
Hide resolved
|
||
cascadingCaseValues.clear(); | ||
} | ||
} | ||
|
||
// Edge case: all cases are simple cascading cases | ||
if (cascadingCases.size() == cases.size()) { | ||
CaseOp lastCascadingCase = cascadingCases.back(); | ||
mergeCascadingInto(lastCascadingCase); | ||
cascadingCases.pop_back(); | ||
flushMergedOps(); | ||
} | ||
|
||
return changed; | ||
} | ||
}; | ||
|
||
//===----------------------------------------------------------------------===// | ||
// CIRSimplifyPass | ||
//===----------------------------------------------------------------------===// | ||
|
@@ -173,7 +274,8 @@ void populateMergeCleanupPatterns(RewritePatternSet &patterns) { | |
// clang-format off | ||
patterns.add< | ||
SimplifyTernary, | ||
SimplifySelect | ||
SimplifySelect, | ||
SimplifySwitch | ||
>(patterns.getContext()); | ||
// clang-format on | ||
} | ||
|
@@ -186,7 +288,7 @@ void CIRSimplifyPass::runOnOperation() { | |
// Collect operations to apply patterns. | ||
llvm::SmallVector<Operation *, 16> ops; | ||
getOperation()->walk([&](Operation *op) { | ||
if (isa<TernaryOp, SelectOp>(op)) | ||
if (isa<TernaryOp, SelectOp, SwitchOp>(op)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you double check if the test passes without issues if There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I have built Clang with the
The There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just checked and it also fails in the incubator. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for walking the extra leg here, can you create an issue for this so we can take a look later? (cc @xlauko which pointed this recently). |
||
ops.push_back(op); | ||
}); | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,196 @@ | ||
// RUN: cir-opt -cir-canonicalize -cir-simplify -o %t.cir %s | ||
// RUN: FileCheck --input-file=%t.cir %s | ||
|
||
!s32i = !cir.int<s, 32> | ||
|
||
module { | ||
cir.func @foldCascade(%arg0: !s32i) { | ||
%0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init] {alignment = 4 : i64} | ||
cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i> | ||
cir.scope { | ||
%1 = cir.load %0 : !cir.ptr<!s32i>, !s32i | ||
cir.switch (%1 : !s32i) { | ||
cir.case(equal, [#cir.int<1> : !s32i]) { | ||
cir.yield | ||
} | ||
cir.case(equal, [#cir.int<2> : !s32i]) { | ||
cir.yield | ||
} | ||
cir.case(equal, [#cir.int<3> : !s32i]) { | ||
%2 = cir.const #cir.int<2> : !s32i | ||
cir.store %2, %0 : !s32i, !cir.ptr<!s32i> | ||
cir.break | ||
} | ||
cir.yield | ||
} | ||
} | ||
cir.return | ||
} | ||
//CHECK: cir.func @foldCascade | ||
//CHECK: cir.switch (%[[COND:.*]] : !s32i) { | ||
//CHECK-NEXT: cir.case(anyof, [#cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.int<3> : !s32i]) { | ||
//CHECK-NEXT: %[[TWO:.*]] = cir.const #cir.int<2> : !s32i | ||
//CHECK-NEXT: cir.store %[[TWO]], %[[ARG0:.*]] : !s32i, !cir.ptr<!s32i> | ||
//CHECK-NEXT: cir.break | ||
//CHECK-NEXT: } | ||
//CHECK-NEXT: cir.yield | ||
//CHECK-NEXT: } | ||
|
||
cir.func @foldCascade2(%arg0: !s32i) { | ||
%0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init] {alignment = 4 : i64} | ||
cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i> | ||
cir.scope { | ||
%1 = cir.load %0 : !cir.ptr<!s32i>, !s32i | ||
cir.switch (%1 : !s32i) { | ||
cir.case(equal, [#cir.int<0> : !s32i]) { | ||
andykaylor marked this conversation as resolved.
Show resolved
Hide resolved
|
||
cir.yield | ||
} | ||
cir.case(equal, [#cir.int<2> : !s32i]) { | ||
cir.yield | ||
} | ||
cir.case(equal, [#cir.int<4> : !s32i]) { | ||
cir.break | ||
} | ||
cir.case(equal, [#cir.int<1> : !s32i]) { | ||
cir.yield | ||
} | ||
cir.case(equal, [#cir.int<3> : !s32i]) { | ||
cir.yield | ||
} | ||
cir.case(equal, [#cir.int<5> : !s32i]) { | ||
cir.break | ||
} | ||
cir.yield | ||
} | ||
} | ||
cir.return | ||
} | ||
//CHECK: @foldCascade2 | ||
//CHECK: cir.switch (%[[COND2:.*]] : !s32i) { | ||
//CHECK: cir.case(anyof, [#cir.int<0> : !s32i, #cir.int<2> : !s32i, #cir.int<4> : !s32i]) { | ||
//CHECK: cir.break | ||
//cehck: } | ||
//CHECK: cir.case(anyof, [#cir.int<1> : !s32i, #cir.int<3> : !s32i, #cir.int<5> : !s32i]) { | ||
//CHECK: cir.break | ||
//CHECK: } | ||
//CHECK: cir.yield | ||
//CHECK: } | ||
cir.func @foldCascade3(%arg0: !s32i ) { | ||
%0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init] {alignment = 4 : i64} | ||
cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i> | ||
cir.scope { | ||
%1 = cir.alloca !s32i, !cir.ptr<!s32i>, ["x"] {alignment = 4 : i64} | ||
%2 = cir.load %0 : !cir.ptr<!s32i>, !s32i | ||
cir.switch (%2 : !s32i) { | ||
cir.case(equal, [#cir.int<0> : !s32i]) { | ||
cir.yield | ||
} | ||
cir.case(equal, [#cir.int<1> : !s32i]) { | ||
cir.yield | ||
} | ||
cir.case(equal, [#cir.int<2> : !s32i]) { | ||
cir.yield | ||
} | ||
cir.case(equal, [#cir.int<3> : !s32i]) { | ||
cir.yield | ||
} | ||
cir.case(equal, [#cir.int<4> : !s32i]) { | ||
cir.yield | ||
} | ||
cir.case(equal, [#cir.int<5> : !s32i]) { | ||
cir.break | ||
} | ||
cir.yield | ||
} | ||
} | ||
cir.return | ||
} | ||
//CHECK: cir.func @foldCascade3 | ||
//CHECK: cir.switch (%[[COND3:.*]] : !s32i) { | ||
//CHECK: cir.case(anyof, [#cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.int<3> : !s32i, #cir.int<4> : !s32i, #cir.int<5> : !s32i]) { | ||
//CHECK: cir.break | ||
//CHECK: } | ||
//CHECK: cir.yield | ||
//CHECK: } | ||
cir.func @foldCascadeWithDefault(%arg0: !s32i ) { | ||
%0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init] {alignment = 4 : i64} | ||
cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i> | ||
cir.scope { | ||
%1 = cir.load %0 : !cir.ptr<!s32i>, !s32i | ||
cir.switch (%1 : !s32i) { | ||
cir.case(equal, [#cir.int<3> : !s32i]) { | ||
cir.break | ||
} | ||
cir.case(equal, [#cir.int<4> : !s32i]) { | ||
cir.yield | ||
} | ||
cir.case(equal, [#cir.int<5> : !s32i]) { | ||
cir.yield | ||
} | ||
cir.case(default, []) { | ||
cir.yield | ||
} | ||
cir.case(equal, [#cir.int<6> : !s32i]) { | ||
cir.yield | ||
} | ||
cir.case(equal, [#cir.int<7> : !s32i]) { | ||
cir.break | ||
} | ||
cir.yield | ||
} | ||
} | ||
cir.return | ||
} | ||
//CHECK: cir.func @foldCascadeWithDefault | ||
//CHECK: cir.switch (%[[COND:.*]] : !s32i) { | ||
//CHECK: cir.case(equal, [#cir.int<3> : !s32i]) { | ||
//CHECK: cir.break | ||
//CHECK: } | ||
//CHECK: cir.case(anyof, [#cir.int<4> : !s32i, #cir.int<5> : !s32i]) { | ||
//CHECK: cir.yield | ||
//CHECK: } | ||
//CHECK: cir.case(default, []) { | ||
//CHECK: cir.yield | ||
//CHECK: } | ||
//CHECK: cir.case(anyof, [#cir.int<6> : !s32i, #cir.int<7> : !s32i]) { | ||
//CHECK: cir.break | ||
//CHECK: } | ||
//CHECK: cir.yield | ||
//CHECK: } | ||
cir.func @foldAllCascade(%arg0: !s32i ) { | ||
%0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init] {alignment = 4 : i64} | ||
cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i> | ||
cir.scope { | ||
%1 = cir.load %0 : !cir.ptr<!s32i>, !s32i | ||
cir.switch (%1 : !s32i) { | ||
cir.case(equal, [#cir.int<0> : !s32i]) { | ||
cir.yield | ||
} | ||
cir.case(equal, [#cir.int<1> : !s32i]) { | ||
cir.yield | ||
} | ||
cir.case(equal, [#cir.int<2> : !s32i]) { | ||
cir.yield | ||
} | ||
cir.case(equal, [#cir.int<3> : !s32i]) { | ||
cir.yield | ||
} | ||
cir.case(equal, [#cir.int<4> : !s32i]) { | ||
cir.yield | ||
} | ||
cir.case(equal, [#cir.int<5> : !s32i]) { | ||
cir.yield | ||
} | ||
cir.yield | ||
} | ||
} | ||
cir.return | ||
} | ||
//CHECK: cir.func @foldAllCascade | ||
//CHECK: cir.switch (%[[COND:.*]] : !s32i) { | ||
//CHECK: cir.case(anyof, [#cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.int<3> : !s32i, #cir.int<4> : !s32i, #cir.int<5> : !s32i]) { | ||
//CHECK: cir.yield | ||
//CHECK: } | ||
//CHECK: cir.yield | ||
//CHECK: } | ||
} |
Uh oh!
There was an error while loading. Please reload this page.