Skip to content

Commit 5eac58e

Browse files
committed
[AST] SwitchStmt only hold CaseStmt
Now that there is no way SwitchStmt to hold AST nodes other than CaseStmt.
1 parent 002d7d7 commit 5eac58e

19 files changed

+63
-114
lines changed

include/swift/AST/Stmt.h

Lines changed: 9 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1408,7 +1408,7 @@ class CaseStmt final
14081408

14091409
/// Switch statement.
14101410
class SwitchStmt final : public LabeledStmt,
1411-
private llvm::TrailingObjects<SwitchStmt, ASTNode> {
1411+
private llvm::TrailingObjects<SwitchStmt, CaseStmt *> {
14121412
friend TrailingObjects;
14131413

14141414
SourceLoc SwitchLoc, LBraceLoc, RBraceLoc;
@@ -1431,16 +1431,13 @@ class SwitchStmt final : public LabeledStmt,
14311431
public:
14321432
/// Allocate a new SwitchStmt in the given ASTContext.
14331433
static SwitchStmt *create(LabeledStmtInfo LabelInfo, SourceLoc SwitchLoc,
1434-
Expr *SubjectExpr,
1435-
SourceLoc LBraceLoc,
1436-
ArrayRef<ASTNode> Cases,
1437-
SourceLoc RBraceLoc,
1438-
SourceLoc EndLoc,
1439-
ASTContext &C);
1434+
Expr *SubjectExpr, SourceLoc LBraceLoc,
1435+
ArrayRef<CaseStmt *> Cases, SourceLoc RBraceLoc,
1436+
SourceLoc EndLoc, ASTContext &C);
14401437

14411438
static SwitchStmt *createImplicit(LabeledStmtInfo LabelInfo,
1442-
Expr *SubjectExpr, ArrayRef<ASTNode> Cases,
1443-
ASTContext &C) {
1439+
Expr *SubjectExpr,
1440+
ArrayRef<CaseStmt *> Cases, ASTContext &C) {
14441441
return SwitchStmt::create(LabelInfo, /*SwitchLoc=*/SourceLoc(), SubjectExpr,
14451442
/*LBraceLoc=*/SourceLoc(), Cases,
14461443
/*RBraceLoc=*/SourceLoc(), /*EndLoc=*/SourceLoc(),
@@ -1463,27 +1460,10 @@ class SwitchStmt final : public LabeledStmt,
14631460
Expr *getSubjectExpr() const { return SubjectExpr; }
14641461
void setSubjectExpr(Expr *e) { SubjectExpr = e; }
14651462

1466-
ArrayRef<ASTNode> getRawCases() const {
1467-
return {getTrailingObjects<ASTNode>(), static_cast<size_t>(Bits.SwitchStmt.CaseCount)};
1468-
}
1469-
1470-
private:
1471-
struct AsCaseStmtWithSkippingNonCaseStmts {
1472-
AsCaseStmtWithSkippingNonCaseStmts() {}
1473-
std::optional<CaseStmt *> operator()(const ASTNode &N) const {
1474-
if (auto *CS = llvm::dyn_cast_or_null<CaseStmt>(N.dyn_cast<Stmt*>()))
1475-
return CS;
1476-
return std::nullopt;
1477-
}
1478-
};
1479-
1480-
public:
1481-
using AsCaseStmtRange = OptionalTransformRange<ArrayRef<ASTNode>,
1482-
AsCaseStmtWithSkippingNonCaseStmts>;
1483-
14841463
/// Get the list of case clauses.
1485-
AsCaseStmtRange getCases() const {
1486-
return AsCaseStmtRange(getRawCases(), AsCaseStmtWithSkippingNonCaseStmts());
1464+
ArrayRef<CaseStmt *> getCases() const {
1465+
return {getTrailingObjects<CaseStmt *>(),
1466+
static_cast<size_t>(Bits.SwitchStmt.CaseCount)};
14871467
}
14881468

14891469
/// Retrieve the complete set of branches for this switch statement.

include/swift/Parse/Parser.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1969,7 +1969,8 @@ class Parser {
19691969
ParserResult<CaseStmt> parseStmtCatch();
19701970
ParserResult<Stmt> parseStmtForEach(LabeledStmtInfo LabelInfo);
19711971
ParserResult<Stmt> parseStmtSwitch(LabeledStmtInfo LabelInfo);
1972-
ParserStatus parseStmtCases(SmallVectorImpl<ASTNode> &cases, bool IsActive);
1972+
ParserStatus parseStmtCases(SmallVectorImpl<CaseStmt *> &cases,
1973+
bool IsActive);
19731974
ParserResult<CaseStmt> parseStmtCase(bool IsActive);
19741975
ParserResult<Stmt> parseStmtPoundAssert();
19751976

lib/AST/ASTDumper.cpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3123,12 +3123,9 @@ class PrintStmt : public StmtVisitor<PrintStmt, void, Label>,
31233123
void visitSwitchStmt(SwitchStmt *S, Label label) {
31243124
printCommon(S, "switch_stmt", label);
31253125
printRec(S->getSubjectExpr(), Label::optional("subject_expr"));
3126-
printList(S->getRawCases(), [&](ASTNode N, Label label) {
3127-
if (N.is<Stmt*>())
3128-
printRec(N.get<Stmt*>(), label);
3129-
else
3130-
printRec(N.get<Decl*>(), label);
3131-
}, Label::optional("cases"));
3126+
printList(
3127+
S->getCases(), [&](CaseStmt *CS, Label label) { printRec(CS, label); },
3128+
Label::optional("cases"));
31323129
printFoot();
31333130
}
31343131
void visitCaseStmt(CaseStmt *S, Label label) {

lib/AST/ASTPrinter.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5657,9 +5657,8 @@ void PrintAST::visitSwitchStmt(SwitchStmt *stmt) {
56575657
visit(stmt->getSubjectExpr());
56585658
Printer << " {";
56595659
Printer.printNewline();
5660-
for (auto N : stmt->getRawCases()) {
5661-
if (N.is<Stmt*>())
5662-
visit(cast<CaseStmt>(N.get<Stmt*>()));
5660+
for (auto *CS : stmt->getCases()) {
5661+
visit(CS);
56635662
Printer.printNewline();
56645663
}
56655664
indent();

lib/AST/ASTWalker.cpp

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2063,15 +2063,12 @@ Stmt *Traversal::visitSwitchStmt(SwitchStmt *S) {
20632063
else
20642064
return nullptr;
20652065

2066-
for (auto N : S->getRawCases()) {
2067-
if (Stmt *aCase = N.dyn_cast<Stmt*>()) {
2068-
assert(isa<CaseStmt>(aCase));
2069-
if (Stmt *aStmt = doIt(aCase)) {
2070-
assert(aCase == aStmt && "switch case remap not supported");
2071-
(void)aStmt;
2072-
} else
2073-
return nullptr;
2074-
}
2066+
for (auto *aCase : S->getCases()) {
2067+
if (Stmt *aStmt = doIt(aCase)) {
2068+
assert(aCase == aStmt && "switch case remap not supported");
2069+
(void)aStmt;
2070+
} else
2071+
return nullptr;
20752072
}
20762073

20772074
return S;

lib/AST/Bridging/StmtBridging.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -251,9 +251,9 @@ BridgedSwitchStmt BridgedSwitchStmt_createParsed(
251251
BridgedSourceLoc cLBraceLoc, BridgedArrayRef cCases,
252252
BridgedSourceLoc cRBraceLoc) {
253253
auto &context = cContext.unbridged();
254-
auto cases =
255-
context.AllocateTransform<ASTNode>(cCases.unbridged<BridgedASTNode>(),
256-
[](auto &e) { return e.unbridged(); });
254+
SmallVector<CaseStmt *, 16> cases;
255+
for (auto cCase : cCases.unbridged<BridgedCaseStmt>())
256+
cases.push_back(cCase.unbridged());
257257
return SwitchStmt::create(cLabelInfo.unbridged(), cSwitchLoc.unbridged(),
258258
cSubjectExpr.unbridged(), cLBraceLoc.unbridged(),
259259
cases, cRBraceLoc.unbridged(),

lib/AST/Stmt.cpp

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -940,26 +940,18 @@ CaseStmt *CaseStmt::findNextCaseStmt() const {
940940
}
941941

942942
SwitchStmt *SwitchStmt::create(LabeledStmtInfo LabelInfo, SourceLoc SwitchLoc,
943-
Expr *SubjectExpr,
944-
SourceLoc LBraceLoc,
945-
ArrayRef<ASTNode> Cases,
946-
SourceLoc RBraceLoc,
947-
SourceLoc EndLoc,
948-
ASTContext &C) {
949-
#ifndef NDEBUG
950-
for (auto N : Cases)
951-
assert(N.is<Stmt*>() && isa<CaseStmt>(N.get<Stmt*>()));
952-
#endif
953-
954-
void *p = C.Allocate(totalSizeToAlloc<ASTNode>(Cases.size()),
943+
Expr *SubjectExpr, SourceLoc LBraceLoc,
944+
ArrayRef<CaseStmt *> Cases, SourceLoc RBraceLoc,
945+
SourceLoc EndLoc, ASTContext &C) {
946+
void *p = C.Allocate(totalSizeToAlloc<CaseStmt *>(Cases.size()),
955947
alignof(SwitchStmt));
956948
SwitchStmt *theSwitch = ::new (p) SwitchStmt(LabelInfo, SwitchLoc,
957949
SubjectExpr, LBraceLoc,
958950
Cases.size(), RBraceLoc,
959951
EndLoc);
960952

961953
std::uninitialized_copy(Cases.begin(), Cases.end(),
962-
theSwitch->getTrailingObjects<ASTNode>());
954+
theSwitch->getTrailingObjects<CaseStmt *>());
963955
for (auto *caseStmt : theSwitch->getCases())
964956
caseStmt->setParentStmt(theSwitch);
965957

lib/ASTGen/Sources/ASTGen/Stmts.swift

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -502,7 +502,7 @@ extension ASTGenVisitor {
502502
}
503503

504504
func generate(switchCaseList node: SwitchCaseListSyntax) -> BridgedArrayRef {
505-
var allBridgedCases: [BridgedASTNode] = []
505+
var allBridgedCases: [BridgedCaseStmt] = []
506506
visitIfConfigElements(node, of: SwitchCaseSyntax.self) { element in
507507
switch element {
508508
case .ifConfigDecl(let ifConfigDecl):
@@ -511,11 +511,11 @@ extension ASTGenVisitor {
511511
return .underlying(switchCase)
512512
}
513513
} body: { caseNode in
514-
allBridgedCases.append(
515-
.stmt(self.generate(switchCase: caseNode).asStmt)
516-
)
514+
allBridgedCases.append(self.generate(switchCase: caseNode))
517515
}
518516

517+
// TODO: Diagnose 'case' after 'default'.
518+
519519
return allBridgedCases.lazy.bridgedArray(in: self)
520520
}
521521

lib/Parse/ParseStmt.cpp

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2529,7 +2529,7 @@ ParserResult<Stmt> Parser::parseStmtSwitch(LabeledStmtInfo LabelInfo) {
25292529

25302530
SourceLoc lBraceLoc;
25312531
SourceLoc rBraceLoc;
2532-
SmallVector<ASTNode, 8> cases;
2532+
SmallVector<CaseStmt *, 8> cases;
25332533

25342534
if (Status.isErrorOrHasCompletion()) {
25352535
return makeParserResult(
@@ -2551,9 +2551,7 @@ ParserResult<Stmt> Parser::parseStmtSwitch(LabeledStmtInfo LabelInfo) {
25512551
// We cannot have additional cases after a default clause. Complain on
25522552
// the first offender.
25532553
bool hasDefault = false;
2554-
for (auto Element : cases) {
2555-
if (!Element.is<Stmt*>()) continue;
2556-
auto *CS = cast<CaseStmt>(Element.get<Stmt*>());
2554+
for (auto *CS : cases) {
25572555
if (hasDefault) {
25582556
diagnose(CS->getLoc(), diag::case_after_default);
25592557
break;
@@ -2572,8 +2570,8 @@ ParserResult<Stmt> Parser::parseStmtSwitch(LabeledStmtInfo LabelInfo) {
25722570
/*EndLoc=*/rBraceLoc, Context));
25732571
}
25742572

2575-
ParserStatus
2576-
Parser::parseStmtCases(SmallVectorImpl<ASTNode> &cases, bool IsActive) {
2573+
ParserStatus Parser::parseStmtCases(SmallVectorImpl<CaseStmt *> &cases,
2574+
bool IsActive) {
25772575
ParserStatus Status;
25782576
while (Tok.isNot(tok::r_brace, tok::eof,
25792577
tok::pound_endif, tok::pound_elseif, tok::pound_else)) {
@@ -2586,15 +2584,14 @@ Parser::parseStmtCases(SmallVectorImpl<ASTNode> &cases, bool IsActive) {
25862584
// '#if' in 'case' position can enclose one or more 'case' or 'default'
25872585
// clauses.
25882586
auto IfConfigResult =
2589-
parseIfConfig(IfConfigContext::SwitchStmt,
2590-
[&](bool IsActive) {
2591-
SmallVector<ASTNode, 16> elements;
2592-
parseStmtCases(elements, IsActive);
2593-
2594-
if (IsActive) {
2595-
cases.append(elements);
2596-
}
2597-
});
2587+
parseIfConfig(IfConfigContext::SwitchStmt, [&](bool IsActive) {
2588+
SmallVector<CaseStmt *, 16> elements;
2589+
parseStmtCases(elements, IsActive);
2590+
2591+
if (IsActive) {
2592+
cases.append(elements);
2593+
}
2594+
});
25982595
Status |= IfConfigResult;
25992596
} else if (Tok.is(tok::pound_warning) || Tok.is(tok::pound_error)) {
26002597
Status |= parseDeclPoundDiagnostic();

lib/SILGen/SILGenPattern.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3455,7 +3455,7 @@ void SILGenFunction::emitSwitchStmt(SwitchStmt *S) {
34553455

34563456
// Add a row for each label of each case.
34573457
SmallVector<ClauseRow, 8> clauseRows;
3458-
clauseRows.reserve(S->getRawCases().size());
3458+
clauseRows.reserve(S->getCases().size());
34593459
bool hasFallthrough = false;
34603460
for (auto caseBlock : S->getCases()) {
34613461
// If the previous block falls through into this block or we have multiple

lib/Sema/BuilderTransform.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -585,7 +585,7 @@ class ResultBuilderTransform
585585
// type-checked first.
586586
SmallVector<ASTNode, 4> doBody;
587587

588-
SmallVector<ASTNode, 4> cases;
588+
SmallVector<CaseStmt *, 4> cases;
589589
SmallVector<Expr *, 4> caseVarRefs;
590590

591591
for (auto *caseStmt : switchStmt->getCases()) {

lib/Sema/CSSyntacticElement.cpp

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1059,8 +1059,8 @@ class SyntacticElementConstraintGenerator
10591059
cs.setTargetFor(switchStmt, target);
10601060
}
10611061

1062-
for (auto rawCase : switchStmt->getRawCases())
1063-
elements.push_back(makeElement(rawCase, locator));
1062+
for (auto &CS : switchStmt->getCases())
1063+
elements.push_back(makeElement(CS, locator));
10641064
}
10651065

10661066
createConjunction(elements, locator);
@@ -1963,18 +1963,12 @@ class SyntacticElementSolutionApplication
19631963

19641964
// Visit the raw cases.
19651965
bool limitExhaustivityChecks = false;
1966-
for (auto rawCase : switchStmt->getRawCases()) {
1967-
if (auto decl = rawCase.dyn_cast<Decl *>()) {
1968-
visitDecl(decl);
1969-
continue;
1970-
}
1971-
1972-
auto caseStmt = cast<CaseStmt>(rawCase.get<Stmt *>());
1966+
for (auto *CS : switchStmt->getCases()) {
19731967
// Body of the `case` statement can contain a `fallthrough`
19741968
// statement that requires both source and destination
19751969
// `case` preambles to be type-checked, so bodies of `case`
19761970
// statements should be visited after preambles.
1977-
visitCaseStmtPreamble(caseStmt);
1971+
visitCaseStmtPreamble(CS);
19781972
}
19791973

19801974
for (auto *caseStmt : switchStmt->getCases()) {

lib/Sema/DerivedConformanceCodable.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -931,7 +931,7 @@ createEnumSwitch(ASTContext &C, DeclContext *DC, Expr *expr, EnumDecl *enumDecl,
931931
std::function<std::tuple<EnumElementDecl *, BraceStmt *>(
932932
EnumElementDecl *, EnumElementDecl *, ArrayRef<VarDecl *>)>
933933
createCase) {
934-
SmallVector<ASTNode, 4> cases;
934+
SmallVector<CaseStmt *, 4> cases;
935935
for (auto elt : enumDecl->getAllElements()) {
936936
// .<elt>(let a0, let a1, ...)
937937
SmallVector<VarDecl *, 3> payloadVars;

lib/Sema/DerivedConformanceCodingKey.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ deriveBodyCodingKey_enum_stringValue(AbstractFunctionDecl *strValDecl, void *) {
209209
body = BraceStmt::create(C, SourceLoc(), ASTNode(returnStmt),
210210
SourceLoc());
211211
} else {
212-
SmallVector<ASTNode, 4> cases;
212+
SmallVector<CaseStmt *, 4> cases;
213213
for (auto *elt : elements) {
214214
auto *pat = EnumElementPattern::createImplicit(enumType, elt,
215215
/*subPattern*/ nullptr,
@@ -270,7 +270,7 @@ deriveBodyCodingKey_init_stringValue(AbstractFunctionDecl *initDecl, void *) {
270270
}
271271

272272
auto *selfRef = DerivedConformance::createSelfDeclRef(initDecl);
273-
SmallVector<ASTNode, 4> cases;
273+
SmallVector<CaseStmt *, 4> cases;
274274
for (auto *elt : elements) {
275275
// Skip the cases that would return unavailable elements since those can't
276276
// be instantiated at runtime.

lib/Sema/DerivedConformanceComparable.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ deriveBodyComparable_enum_hasAssociatedValues_lt(AbstractFunctionDecl *ltDecl, v
103103
auto enumDecl = cast<EnumDecl>(aParam->getInterfaceType()->getAnyNominal());
104104

105105
SmallVector<ASTNode, 8> statements;
106-
SmallVector<ASTNode, 4> cases;
106+
SmallVector<CaseStmt *, 4> cases;
107107
unsigned elementCount = 0; // need this as `getAllElements` returns a generator
108108

109109
// For each enum element, generate a case statement matching a pair containing

lib/Sema/DerivedConformanceEquatableHashable.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ deriveBodyEquatable_enum_uninhabited_eq(AbstractFunctionDecl *eqDecl, void *) {
6969
assert(!cast<EnumDecl>(aParam->getInterfaceType()->getAnyNominal())->hasCases());
7070

7171
SmallVector<ASTNode, 1> statements;
72-
SmallVector<ASTNode, 0> cases;
72+
SmallVector<CaseStmt *, 0> cases;
7373

7474
// switch (a, b) { }
7575
auto aRef = new (C) DeclRefExpr(aParam, DeclNameLoc(), /*implicit*/ true,
@@ -163,7 +163,7 @@ deriveBodyEquatable_enum_hasAssociatedValues_eq(AbstractFunctionDecl *eqDecl,
163163
auto enumDecl = cast<EnumDecl>(aParam->getInterfaceType()->getAnyNominal());
164164

165165
SmallVector<ASTNode, 6> statements;
166-
SmallVector<ASTNode, 4> cases;
166+
SmallVector<CaseStmt *, 4> cases;
167167
unsigned elementCount = 0;
168168

169169
// For each enum element, generate a case statement matching a pair containing
@@ -687,7 +687,7 @@ deriveBodyHashable_enum_hasAssociatedValues_hashInto(
687687
auto hasherParam = hashIntoDecl->getParameters()->get(0);
688688

689689
unsigned index = 0;
690-
SmallVector<ASTNode, 4> cases;
690+
SmallVector<CaseStmt *, 4> cases;
691691

692692
// For each enum element, generate a case statement that binds the associated
693693
// values so that they can be fed to the hasher.

lib/Sema/DerivedConformanceRawRepresentable.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ deriveBodyRawRepresentable_raw(AbstractFunctionDecl *toRawDecl, void *) {
114114

115115
Type enumType = parentDC->getDeclaredTypeInContext();
116116

117-
SmallVector<ASTNode, 4> cases;
117+
SmallVector<CaseStmt *, 4> cases;
118118
for (auto elt : enumDecl->getAllElements()) {
119119
auto *pat = EnumElementPattern::createImplicit(
120120
enumType, elt, /*subPattern*/ nullptr, /*DC*/ toRawDecl);
@@ -311,8 +311,8 @@ deriveBodyRawRepresentable_init(AbstractFunctionDecl *initDecl, void *) {
311311
Type enumType = parentDC->getDeclaredTypeInContext();
312312

313313
auto selfDecl = cast<ConstructorDecl>(initDecl)->getImplicitSelfDecl();
314-
315-
SmallVector<ASTNode, 4> cases;
314+
315+
SmallVector<CaseStmt *, 4> cases;
316316
unsigned Idx = 0;
317317
for (auto elt : enumDecl->getAllElements()) {
318318
// First, check case availability. If the case will definitely be

lib/Sema/DerivedConformances.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -771,7 +771,7 @@ DeclRefExpr *DerivedConformance::convertEnumToIndex(SmallVectorImpl<ASTNode> &st
771771
C, StaticSpellingKind::None, indexPat, /*InitExpr*/ nullptr, funcDecl);
772772

773773
unsigned index = 0;
774-
SmallVector<ASTNode, 4> cases;
774+
SmallVector<CaseStmt *, 4> cases;
775775
for (auto elt : enumDecl->getAllElements()) {
776776
if (auto *unavailableElementCase =
777777
DerivedConformance::unavailableEnumElementCaseStmt(enumType, elt,

0 commit comments

Comments
 (0)