Skip to content

Commit 5d478bd

Browse files
committed
[Parse] Allow #if to guard switch case clauses
Resolves: https://bugs.swift.org/browse/SR-4196 https://bugs.swift.org/browse/SR-2
1 parent 2f82acc commit 5d478bd

16 files changed

+474
-76
lines changed

include/swift/AST/Stmt.h

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -932,7 +932,7 @@ class CaseStmt final : public Stmt,
932932

933933
/// Switch statement.
934934
class SwitchStmt final : public LabeledStmt,
935-
private llvm::TrailingObjects<SwitchStmt, CaseStmt *> {
935+
private llvm::TrailingObjects<SwitchStmt, ASTNode> {
936936
friend TrailingObjects;
937937

938938
SourceLoc SwitchLoc, LBraceLoc, RBraceLoc;
@@ -953,7 +953,7 @@ class SwitchStmt final : public LabeledStmt,
953953
static SwitchStmt *create(LabeledStmtInfo LabelInfo, SourceLoc SwitchLoc,
954954
Expr *SubjectExpr,
955955
SourceLoc LBraceLoc,
956-
ArrayRef<CaseStmt*> Cases,
956+
ArrayRef<ASTNode> Cases,
957957
SourceLoc RBraceLoc,
958958
ASTContext &C);
959959

@@ -972,10 +972,28 @@ class SwitchStmt final : public LabeledStmt,
972972
/// Get the subject expression of the switch.
973973
Expr *getSubjectExpr() const { return SubjectExpr; }
974974
void setSubjectExpr(Expr *e) { SubjectExpr = e; }
975+
976+
ArrayRef<ASTNode> getRawCases() const {
977+
return {getTrailingObjects<ASTNode>(), CaseCount};
978+
}
979+
980+
private:
981+
struct AsCaseStmtWithSkippingIfConfig {
982+
AsCaseStmtWithSkippingIfConfig() {}
983+
Optional<CaseStmt*> operator()(const ASTNode &N) const {
984+
if (auto *CS = llvm::dyn_cast_or_null<CaseStmt>(N.dyn_cast<Stmt*>()))
985+
return CS;
986+
return None;
987+
}
988+
};
989+
990+
public:
991+
using AsCaseStmtRange = OptionalTransformRange<ArrayRef<ASTNode>,
992+
AsCaseStmtWithSkippingIfConfig>;
975993

976994
/// Get the list of case clauses.
977-
ArrayRef<CaseStmt*> getCases() const {
978-
return {getTrailingObjects<CaseStmt*>(), CaseCount};
995+
AsCaseStmtRange getCases() const {
996+
return AsCaseStmtRange(getRawCases(), AsCaseStmtWithSkippingIfConfig());
979997
}
980998

981999
static bool classof(const Stmt *S) {

include/swift/Parse/Parser.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1269,6 +1269,8 @@ class Parser {
12691269
// Statement Parsing
12701270

12711271
bool isStartOfStmt();
1272+
bool isTerminatorForBraceItemListKind(BraceItemListKind Kind,
1273+
ArrayRef<ASTNode> ParsedDecls);
12721274
ParserResult<Stmt> parseStmt();
12731275
ParserStatus parseExprOrStmt(ASTNode &Result);
12741276
ParserResult<Stmt> parseStmtBreak();
@@ -1291,7 +1293,8 @@ class Parser {
12911293
ParserResult<Stmt> parseStmtForEach(SourceLoc ForLoc,
12921294
LabeledStmtInfo LabelInfo);
12931295
ParserResult<Stmt> parseStmtSwitch(LabeledStmtInfo LabelInfo);
1294-
ParserResult<CaseStmt> parseStmtCase();
1296+
ParserStatus parseStmtCases(SmallVectorImpl<ASTNode> &cases, bool IsActive);
1297+
ParserResult<CaseStmt> parseStmtCase(bool IsActive);
12951298

12961299
//===--------------------------------------------------------------------===//
12971300
// Generics Parsing

lib/AST/ASTDumper.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1522,9 +1522,12 @@ class PrintStmt : public StmtVisitor<PrintStmt> {
15221522
void visitSwitchStmt(SwitchStmt *S) {
15231523
printCommon(S, "switch_stmt") << '\n';
15241524
printRec(S->getSubjectExpr());
1525-
for (CaseStmt *C : S->getCases()) {
1525+
for (auto N : S->getRawCases()) {
15261526
OS << '\n';
1527-
printRec(C);
1527+
if (N.is<Stmt*>())
1528+
printRec(N.get<Stmt*>());
1529+
else
1530+
printRec(N.get<Decl*>());
15281531
}
15291532
PrintWithColorRAII(OS, ParenthesisColor) << ')';
15301533
}

lib/AST/ASTPrinter.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3359,8 +3359,11 @@ void PrintAST::visitSwitchStmt(SwitchStmt *stmt) {
33593359
// FIXME: print subject
33603360
Printer << "{";
33613361
Printer.printNewline();
3362-
for (CaseStmt *C : stmt->getCases()) {
3363-
visit(C);
3362+
for (auto N : stmt->getRawCases()) {
3363+
if (N.is<Stmt*>())
3364+
visit(cast<CaseStmt>(N.get<Stmt*>()));
3365+
else
3366+
visit(cast<IfConfigDecl>(N.get<Decl*>()));
33643367
}
33653368
Printer.printNewline();
33663369
indent();

lib/AST/ASTWalker.cpp

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1452,12 +1452,19 @@ Stmt *Traversal::visitSwitchStmt(SwitchStmt *S) {
14521452
else
14531453
return nullptr;
14541454

1455-
for (CaseStmt *aCase : S->getCases()) {
1456-
if (Stmt *aStmt = doIt(aCase)) {
1457-
assert(aCase == aStmt && "switch case remap not supported");
1458-
(void)aStmt;
1459-
} else
1460-
return nullptr;
1455+
for (auto N : S->getRawCases()) {
1456+
if (Stmt *aCase = N.dyn_cast<Stmt*>()) {
1457+
assert(isa<CaseStmt>(aCase));
1458+
if (Stmt *aStmt = doIt(aCase)) {
1459+
assert(aCase == aStmt && "switch case remap not supported");
1460+
(void)aStmt;
1461+
} else
1462+
return nullptr;
1463+
} else {
1464+
assert(isa<IfConfigDecl>(N.get<Decl*>()));
1465+
if (doIt(N.get<Decl*>()))
1466+
return nullptr;
1467+
}
14611468
}
14621469

14631470
return S;

lib/AST/Stmt.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -412,15 +412,22 @@ CaseStmt *CaseStmt::create(ASTContext &C, SourceLoc CaseLoc,
412412
SwitchStmt *SwitchStmt::create(LabeledStmtInfo LabelInfo, SourceLoc SwitchLoc,
413413
Expr *SubjectExpr,
414414
SourceLoc LBraceLoc,
415-
ArrayRef<CaseStmt *> Cases,
415+
ArrayRef<ASTNode> Cases,
416416
SourceLoc RBraceLoc,
417417
ASTContext &C) {
418-
void *p = C.Allocate(totalSizeToAlloc<CaseStmt *>(Cases.size()),
418+
#ifndef NDEBUG
419+
for (auto N : Cases)
420+
assert((N.is<Stmt*>() && isa<CaseStmt>(N.get<Stmt*>())) ||
421+
(N.is<Decl*>() && isa<IfConfigDecl>(N.get<Decl*>())));
422+
#endif
423+
424+
void *p = C.Allocate(totalSizeToAlloc<ASTNode>(Cases.size()),
419425
alignof(SwitchStmt));
420426
SwitchStmt *theSwitch = ::new (p) SwitchStmt(LabelInfo, SwitchLoc,
421427
SubjectExpr, LBraceLoc,
422428
Cases.size(), RBraceLoc);
429+
423430
std::uninitialized_copy(Cases.begin(), Cases.end(),
424-
theSwitch->getTrailingObjects<CaseStmt *>());
431+
theSwitch->getTrailingObjects<ASTNode>());
425432
return theSwitch;
426433
}

lib/Parse/ParseStmt.cpp

Lines changed: 76 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -130,14 +130,25 @@ ParserStatus Parser::parseExprOrStmt(ASTNode &Result) {
130130
return ResultExpr;
131131
}
132132

133-
static bool isTerminatorForBraceItemListKind(const Token &Tok,
134-
BraceItemListKind Kind,
135-
ArrayRef<ASTNode> ParsedDecls) {
133+
bool Parser::isTerminatorForBraceItemListKind(BraceItemListKind Kind,
134+
ArrayRef<ASTNode> ParsedDecls) {
136135
switch (Kind) {
137136
case BraceItemListKind::Brace:
138137
return false;
139138
case BraceItemListKind::Case:
140-
return Tok.is(tok::kw_case) || Tok.is(tok::kw_default);
139+
if (Tok.is(tok::pound_if)) {
140+
// '#if' here could be to guard 'case:' or statements in cases.
141+
// If the next non-directive line starts with 'case' or 'default', it is
142+
// for 'case's.
143+
Parser::BacktrackingScope Backtrack(*this);
144+
do {
145+
consumeToken();
146+
while (!Tok.isAtStartOfLine() && Tok.isNot(tok::eof))
147+
skipSingle();
148+
} while (Tok.isAny(tok::pound_if, tok::pound_elseif, tok::pound_else));
149+
return Tok.isAny(tok::kw_case, tok::kw_default);
150+
}
151+
return Tok.isAny(tok::kw_case, tok::kw_default);
141152
case BraceItemListKind::TopLevelCode:
142153
// When parsing the top level executable code for a module, if we parsed
143154
// some executable code, then we're done. We want to process (name bind,
@@ -247,7 +258,7 @@ ParserStatus Parser::parseBraceItems(SmallVectorImpl<ASTNode> &Entries,
247258
Tok.isNot(tok::kw_sil_witness_table) &&
248259
Tok.isNot(tok::kw_sil_default_witness_table) &&
249260
(isConditionalBlock ||
250-
!isTerminatorForBraceItemListKind(Tok, Kind, Entries))) {
261+
!isTerminatorForBraceItemListKind(Kind, Entries))) {
251262
if (Kind == BraceItemListKind::TopLevelLibrary &&
252263
skipExtraTopLevelRBraces())
253264
continue;
@@ -2151,36 +2162,20 @@ ParserResult<Stmt> Parser::parseStmtSwitch(LabeledStmtInfo LabelInfo) {
21512162
SourceLoc lBraceLoc = consumeToken(tok::l_brace);
21522163
SourceLoc rBraceLoc;
21532164

2154-
// If there are non-case-label statements at the start of the switch body,
2155-
// raise an error and recover by discarding them.
2156-
bool DiagnosedNotCoveredStmt = false;
2157-
while (!Tok.is(tok::kw_case) && !Tok.is(tok::kw_default)
2158-
&& !Tok.is(tok::r_brace) && !Tok.is(tok::eof)) {
2159-
if (!DiagnosedNotCoveredStmt) {
2160-
diagnose(Tok, diag::stmt_in_switch_not_covered_by_case);
2161-
DiagnosedNotCoveredStmt = true;
2162-
}
2163-
skipSingle();
2164-
}
2165-
2166-
SmallVector<CaseStmt*, 8> cases;
2167-
bool parsedDefault = false;
2168-
bool parsedBlockAfterDefault = false;
2169-
while (Tok.is(tok::kw_case) || Tok.is(tok::kw_default)) {
2170-
// We cannot have additional cases after a default clause. Complain on
2171-
// the first offender.
2172-
if (parsedDefault && !parsedBlockAfterDefault) {
2173-
parsedBlockAfterDefault = true;
2174-
diagnose(Tok, diag::case_after_default);
2175-
}
2176-
2177-
ParserResult<CaseStmt> Case = parseStmtCase();
2178-
Status |= Case;
2179-
if (Case.isNonNull()) {
2180-
cases.push_back(Case.get());
2181-
if (Case.get()->isDefault())
2182-
parsedDefault = true;
2165+
SmallVector<ASTNode, 8> cases;
2166+
Status |= parseStmtCases(cases, /*IsActive=*/true);
2167+
2168+
// We cannot have additional cases after a default clause. Complain on
2169+
// the first offender.
2170+
bool hasDefault = false;
2171+
for (auto Element : cases) {
2172+
if (!Element.is<Stmt*>()) continue;
2173+
auto *CS = cast<CaseStmt>(Element.get<Stmt*>());
2174+
if (hasDefault) {
2175+
diagnose(CS->getLoc(), diag::case_after_default);
2176+
break;
21832177
}
2178+
hasDefault |= CS->isDefault();
21842179
}
21852180

21862181
if (parseMatchingToken(tok::r_brace, rBraceLoc,
@@ -2193,6 +2188,51 @@ ParserResult<Stmt> Parser::parseStmtSwitch(LabeledStmtInfo LabelInfo) {
21932188
lBraceLoc, cases, rBraceLoc, Context));
21942189
}
21952190

2191+
ParserStatus
2192+
Parser::parseStmtCases(SmallVectorImpl<ASTNode> &cases, bool IsActive) {
2193+
ParserStatus Status;
2194+
while (Tok.isNot(tok::r_brace, tok::eof,
2195+
tok::pound_endif, tok::pound_elseif, tok::pound_else)) {
2196+
if (Tok.isAny(tok::kw_case, tok::kw_default)) {
2197+
ParserResult<CaseStmt> Case = parseStmtCase(IsActive);
2198+
Status |= Case;
2199+
if (Case.isNonNull())
2200+
cases.emplace_back(Case.get());
2201+
} else if (Tok.is(tok::pound_if)) {
2202+
// '#if' in 'case' position can enclose one or more 'case' or 'default'
2203+
// clauses.
2204+
auto IfConfigResult = parseIfConfig(
2205+
[&](SmallVectorImpl<ASTNode> &Elements, bool IsActive) {
2206+
parseStmtCases(Elements, IsActive);
2207+
});
2208+
Status |= IfConfigResult;
2209+
if (auto ICD = IfConfigResult.getPtrOrNull()) {
2210+
cases.emplace_back(ICD);
2211+
2212+
for (auto &Entry : ICD->getActiveClauseElements()) {
2213+
if (Entry.is<Decl*>() && isa<IfConfigDecl>(Entry.get<Decl*>()))
2214+
// Don't hoist nested '#if'.
2215+
continue;
2216+
2217+
assert(Entry.is<Stmt*>() && isa<CaseStmt>(Entry.get<Stmt*>()));
2218+
cases.push_back(Entry);
2219+
}
2220+
}
2221+
} else {
2222+
// If there are non-case-label statements at the start of the switch body,
2223+
// raise an error and recover by discarding them.
2224+
diagnose(Tok, diag::stmt_in_switch_not_covered_by_case);
2225+
2226+
while (Tok.isNot(tok::r_brace, tok::eof, tok::pound_elseif,
2227+
tok::pound_else, tok::pound_endif) &&
2228+
!isTerminatorForBraceItemListKind(BraceItemListKind::Case, {})) {
2229+
skipSingle();
2230+
}
2231+
}
2232+
}
2233+
return Status;
2234+
}
2235+
21962236
static ParserStatus parseStmtCase(Parser &P, SourceLoc &CaseLoc,
21972237
SmallVectorImpl<CaseLabelItem> &LabelItems,
21982238
SmallVectorImpl<VarDecl *> &BoundDecls,
@@ -2257,9 +2297,9 @@ parseStmtCaseDefault(Parser &P, SourceLoc &CaseLoc,
22572297
return Status;
22582298
}
22592299

2260-
ParserResult<CaseStmt> Parser::parseStmtCase() {
2300+
ParserResult<CaseStmt> Parser::parseStmtCase(bool IsActive) {
22612301
// A case block has its own scope for variables bound out of the pattern.
2262-
Scope S(this, ScopeKind::CaseVars);
2302+
Scope S(this, ScopeKind::CaseVars, !IsActive);
22632303

22642304
ParserStatus Status;
22652305

lib/SILGen/SILGenPattern.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2616,7 +2616,7 @@ void SILGenFunction::emitSwitchStmt(SwitchStmt *S) {
26162616
// We use std::vector because it supports emplace_back; moving a ClauseRow is
26172617
// expensive.
26182618
std::vector<ClauseRow> clauseRows;
2619-
clauseRows.reserve(S->getCases().size());
2619+
clauseRows.reserve(S->getRawCases().size());
26202620
bool hasFallthrough = false;
26212621
for (auto caseBlock : S->getCases()) {
26222622
for (auto &labelItem : caseBlock->getCaseLabelItems()) {

lib/Sema/DerivedConformanceCodingKey.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ deriveBodyCodingKey_enum_stringValue(AbstractFunctionDecl *strValDecl) {
271271
body = BraceStmt::create(C, SourceLoc(), ASTNode(returnStmt),
272272
SourceLoc());
273273
} else {
274-
SmallVector<CaseStmt *, 4> cases;
274+
SmallVector<ASTNode, 4> cases;
275275
for (auto *elt : elements) {
276276
auto *pat = new (C) EnumElementPattern(TypeLoc::withoutLoc(enumType),
277277
SourceLoc(), SourceLoc(),
@@ -336,7 +336,7 @@ deriveBodyCodingKey_init_stringValue(AbstractFunctionDecl *initDecl) {
336336
}
337337

338338
auto *selfRef = createSelfDeclRef(initDecl);
339-
SmallVector<CaseStmt *, 4> cases;
339+
SmallVector<ASTNode, 4> cases;
340340
for (auto *elt : elements) {
341341
auto *litExpr = new (C) StringLiteralExpr(elt->getNameStr(), SourceRange(),
342342
/*Implicit=*/true);

lib/Sema/DerivedConformanceEquatableHashable.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ static DeclRefExpr *convertEnumToIndex(SmallVectorImpl<ASTNode> &stmts,
8888
indexPat, nullptr, funcDecl);
8989

9090
unsigned index = 0;
91-
SmallVector<CaseStmt*, 4> cases;
91+
SmallVector<ASTNode, 4> cases;
9292
for (auto elt : enumDecl->getAllElements()) {
9393
// generate: case .<Case>:
9494
auto pat = new (C) EnumElementPattern(TypeLoc::withoutLoc(enumType),

lib/Sema/DerivedConformanceRawRepresentable.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ static void deriveBodyRawRepresentable_raw(AbstractFunctionDecl *toRawDecl) {
9393

9494
Type enumType = parentDC->getDeclaredTypeInContext();
9595

96-
SmallVector<CaseStmt*, 4> cases;
96+
SmallVector<ASTNode, 4> cases;
9797
for (auto elt : enumDecl->getAllElements()) {
9898
auto pat = new (C) EnumElementPattern(TypeLoc::withoutLoc(enumType),
9999
SourceLoc(), SourceLoc(),
@@ -198,7 +198,7 @@ deriveBodyRawRepresentable_init(AbstractFunctionDecl *initDecl) {
198198

199199
auto selfDecl = cast<ConstructorDecl>(initDecl)->getImplicitSelfDecl();
200200

201-
SmallVector<CaseStmt*, 4> cases;
201+
SmallVector<ASTNode, 4> cases;
202202
for (auto elt : enumDecl->getAllElements()) {
203203
auto litExpr = cloneRawLiteralExpr(C, elt->getRawValueExpr());
204204
auto litPat = new (C) ExprPattern(litExpr, /*isResolved*/ true,

lib/Sema/TypeCheckStmt.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -863,11 +863,12 @@ class StmtChecker : public StmtVisitor<StmtChecker, Stmt*> {
863863
AddSwitchNest switchNest(*this);
864864
AddLabeledStmt labelNest(*this, S);
865865

866-
for (unsigned i = 0, e = S->getCases().size(); i < e; ++i) {
867-
auto *caseBlock = S->getCases()[i];
866+
auto cases = S->getCases();
867+
for (auto i = cases.begin(), e = cases.end(); i != e; ++i) {
868+
auto *caseBlock = *i;
868869
// Fallthrough transfers control to the next case block. In the
869870
// final case block, it is invalid.
870-
FallthroughDest = i+1 == e ? nullptr : S->getCases()[i+1];
871+
FallthroughDest = std::next(i) == e ? nullptr : *std::next(i);
871872

872873
for (auto &labelItem : caseBlock->getMutableCaseLabelItems()) {
873874
// Resolve the pattern in the label.

lib/Sema/TypeCheckSwitchStmt.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -917,8 +917,7 @@ namespace {
917917
bool sawDowngradablePattern = false;
918918
bool sawRedundantPattern = false;
919919
SmallVector<Space, 4> spaces;
920-
for (unsigned i = 0, e = Switch->getCases().size(); i < e; ++i) {
921-
auto *caseBlock = Switch->getCases()[i];
920+
for (auto *caseBlock : Switch->getCases()) {
922921
for (auto &caseItem : caseBlock->getCaseLabelItems()) {
923922
// 'where'-clauses on cases mean the case does not contribute to
924923
// the exhaustiveness of the pattern.

0 commit comments

Comments
 (0)