Skip to content

Commit f74d893

Browse files
[CIR] Upstream support for switch statements case kinds (#138003)
This introduces support for the following cir::case kinds: - `Equal` - `AnyOf` - `Range`
1 parent 16107c8 commit f74d893

File tree

4 files changed

+442
-43
lines changed

4 files changed

+442
-43
lines changed

clang/include/clang/CIR/MissingFeatures.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,10 @@ struct MissingFeatures {
111111
// Unary operator handling
112112
static bool opUnaryPromotionType() { return false; }
113113

114+
// SwitchOp handling
115+
static bool foldCascadingCases() { return false; }
116+
static bool foldRangeCase() { return false; }
117+
114118
// Clang early optimizations or things defered to LLVM lowering.
115119
static bool mayHaveIntegerOverflow() { return false; }
116120
static bool shouldReverseUnaryCondOnBoolExpr() { return false; }
@@ -176,7 +180,6 @@ struct MissingFeatures {
176180
static bool targetSpecificCXXABI() { return false; }
177181
static bool moduleNameHash() { return false; }
178182
static bool setDSOLocal() { return false; }
179-
static bool foldCaseStmt() { return false; }
180183
static bool constantFoldSwitchStatement() { return false; }
181184
static bool cudaSupport() { return false; }
182185
static bool maybeHandleStaticInExternC() { return false; }

clang/lib/CIR/CodeGen/CIRGenFunction.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -581,6 +581,10 @@ class CIRGenFunction : public CIRGenTypeCache {
581581
mlir::LogicalResult emitDeclStmt(const clang::DeclStmt &s);
582582
LValue emitDeclRefLValue(const clang::DeclRefExpr *e);
583583

584+
mlir::LogicalResult emitDefaultStmt(const clang::DefaultStmt &s,
585+
mlir::Type condType,
586+
bool buildingTopLevelCase);
587+
584588
/// Emit an `if` on a boolean condition to the specified blocks.
585589
/// FIXME: Based on the condition, this might try to simplify the codegen of
586590
/// the conditional based on the branch.

clang/lib/CIR/CodeGen/CIRGenStmt.cpp

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,7 @@ mlir::LogicalResult CIRGenFunction::emitSimpleStmt(const Stmt *s,
254254
case Stmt::NullStmtClass:
255255
break;
256256
case Stmt::CaseStmtClass:
257+
case Stmt::DefaultStmtClass:
257258
// If we reached here, we must not handling a switch case in the top level.
258259
return emitSwitchCase(cast<SwitchCase>(*s),
259260
/*buildingTopLevelCase=*/false);
@@ -458,7 +459,7 @@ CIRGenFunction::emitCaseDefaultCascade(const T *stmt, mlir::Type condType,
458459
if (isa<DefaultStmt>(sub) && isa<CaseStmt>(stmt)) {
459460
subStmtKind = SubStmtKind::Default;
460461
builder.createYield(loc);
461-
} else if (isa<CaseStmt>(sub) && isa<DefaultStmt>(stmt)) {
462+
} else if (isa<CaseStmt>(sub) && isa<DefaultStmt, CaseStmt>(stmt)) {
462463
subStmtKind = SubStmtKind::Case;
463464
builder.createYield(loc);
464465
} else {
@@ -503,8 +504,8 @@ CIRGenFunction::emitCaseDefaultCascade(const T *stmt, mlir::Type condType,
503504
if (subStmtKind == SubStmtKind::Case) {
504505
result = emitCaseStmt(*cast<CaseStmt>(sub), condType, buildingTopLevelCase);
505506
} else if (subStmtKind == SubStmtKind::Default) {
506-
getCIRGenModule().errorNYI(sub->getSourceRange(), "Default case");
507-
return mlir::failure();
507+
result = emitDefaultStmt(*cast<DefaultStmt>(sub), condType,
508+
buildingTopLevelCase);
508509
} else if (buildingTopLevelCase) {
509510
// If we're building a top level case, try to restore the insert point to
510511
// the case we're building, then we can attach more random stmts to the
@@ -518,19 +519,40 @@ CIRGenFunction::emitCaseDefaultCascade(const T *stmt, mlir::Type condType,
518519
mlir::LogicalResult CIRGenFunction::emitCaseStmt(const CaseStmt &s,
519520
mlir::Type condType,
520521
bool buildingTopLevelCase) {
522+
cir::CaseOpKind kind;
523+
mlir::ArrayAttr value;
521524
llvm::APSInt intVal = s.getLHS()->EvaluateKnownConstInt(getContext());
522-
SmallVector<mlir::Attribute, 1> caseEltValueListAttr;
523-
caseEltValueListAttr.push_back(cir::IntAttr::get(condType, intVal));
524-
mlir::ArrayAttr value = builder.getArrayAttr(caseEltValueListAttr);
525-
if (s.getRHS()) {
526-
getCIRGenModule().errorNYI(s.getSourceRange(), "SwitchOp range kind");
527-
return mlir::failure();
525+
526+
// If the case statement has an RHS value, it is representing a GNU
527+
// case range statement, where LHS is the beginning of the range
528+
// and RHS is the end of the range.
529+
if (const Expr *rhs = s.getRHS()) {
530+
llvm::APSInt endVal = rhs->EvaluateKnownConstInt(getContext());
531+
value = builder.getArrayAttr({cir::IntAttr::get(condType, intVal),
532+
cir::IntAttr::get(condType, endVal)});
533+
kind = cir::CaseOpKind::Range;
534+
535+
// We don't currently fold case range statements with other case statements.
536+
// TODO(cir): Add this capability. Folding these cases is going to be
537+
// implemented in CIRSimplify when it is upstreamed.
538+
assert(!cir::MissingFeatures::foldRangeCase());
539+
assert(!cir::MissingFeatures::foldCascadingCases());
540+
} else {
541+
value = builder.getArrayAttr({cir::IntAttr::get(condType, intVal)});
542+
kind = cir::CaseOpKind::Equal;
528543
}
529-
assert(!cir::MissingFeatures::foldCaseStmt());
530-
return emitCaseDefaultCascade(&s, condType, value, cir::CaseOpKind::Equal,
544+
545+
return emitCaseDefaultCascade(&s, condType, value, kind,
531546
buildingTopLevelCase);
532547
}
533548

549+
mlir::LogicalResult CIRGenFunction::emitDefaultStmt(const clang::DefaultStmt &s,
550+
mlir::Type condType,
551+
bool buildingTopLevelCase) {
552+
return emitCaseDefaultCascade(&s, condType, builder.getArrayAttr({}),
553+
cir::CaseOpKind::Default, buildingTopLevelCase);
554+
}
555+
534556
mlir::LogicalResult CIRGenFunction::emitSwitchCase(const SwitchCase &s,
535557
bool buildingTopLevelCase) {
536558
assert(!condTypeStack.empty() &&
@@ -540,10 +562,9 @@ mlir::LogicalResult CIRGenFunction::emitSwitchCase(const SwitchCase &s,
540562
return emitCaseStmt(cast<CaseStmt>(s), condTypeStack.back(),
541563
buildingTopLevelCase);
542564

543-
if (s.getStmtClass() == Stmt::DefaultStmtClass) {
544-
getCIRGenModule().errorNYI(s.getSourceRange(), "Default case");
545-
return mlir::failure();
546-
}
565+
if (s.getStmtClass() == Stmt::DefaultStmtClass)
566+
return emitDefaultStmt(cast<DefaultStmt>(s), condTypeStack.back(),
567+
buildingTopLevelCase);
547568

548569
llvm_unreachable("expect case or default stmt");
549570
}

0 commit comments

Comments
 (0)