Skip to content

Commit d07eb71

Browse files
authored
Merge pull request #9457 from rintaro/parse-ifconfig-switchcase
[Parse] Allow #if to guard switch case clauses
2 parents a2b89a7 + f6310b4 commit d07eb71

16 files changed

+517
-76
lines changed

include/swift/AST/Stmt.h

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -932,7 +932,7 @@ class CaseStmt final : public Stmt,
932932

933933
/// Switch statement.
934934
class SwitchStmt final : public LabeledStmt,
935-
private llvm::TrailingObjects<SwitchStmt, CaseStmt *> {
935+
private llvm::TrailingObjects<SwitchStmt, ASTNode> {
936936
friend TrailingObjects;
937937

938938
SourceLoc SwitchLoc, LBraceLoc, RBraceLoc;
@@ -953,7 +953,7 @@ class SwitchStmt final : public LabeledStmt,
953953
static SwitchStmt *create(LabeledStmtInfo LabelInfo, SourceLoc SwitchLoc,
954954
Expr *SubjectExpr,
955955
SourceLoc LBraceLoc,
956-
ArrayRef<CaseStmt*> Cases,
956+
ArrayRef<ASTNode> Cases,
957957
SourceLoc RBraceLoc,
958958
ASTContext &C);
959959

@@ -972,10 +972,28 @@ class SwitchStmt final : public LabeledStmt,
972972
/// Get the subject expression of the switch.
973973
Expr *getSubjectExpr() const { return SubjectExpr; }
974974
void setSubjectExpr(Expr *e) { SubjectExpr = e; }
975+
976+
ArrayRef<ASTNode> getRawCases() const {
977+
return {getTrailingObjects<ASTNode>(), CaseCount};
978+
}
979+
980+
private:
981+
struct AsCaseStmtWithSkippingIfConfig {
982+
AsCaseStmtWithSkippingIfConfig() {}
983+
Optional<CaseStmt*> operator()(const ASTNode &N) const {
984+
if (auto *CS = llvm::dyn_cast_or_null<CaseStmt>(N.dyn_cast<Stmt*>()))
985+
return CS;
986+
return None;
987+
}
988+
};
989+
990+
public:
991+
using AsCaseStmtRange = OptionalTransformRange<ArrayRef<ASTNode>,
992+
AsCaseStmtWithSkippingIfConfig>;
975993

976994
/// Get the list of case clauses.
977-
ArrayRef<CaseStmt*> getCases() const {
978-
return {getTrailingObjects<CaseStmt*>(), CaseCount};
995+
AsCaseStmtRange getCases() const {
996+
return AsCaseStmtRange(getRawCases(), AsCaseStmtWithSkippingIfConfig());
979997
}
980998

981999
static bool classof(const Stmt *S) {

include/swift/Parse/Parser.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1269,6 +1269,8 @@ class Parser {
12691269
// Statement Parsing
12701270

12711271
bool isStartOfStmt();
1272+
bool isTerminatorForBraceItemListKind(BraceItemListKind Kind,
1273+
ArrayRef<ASTNode> ParsedDecls);
12721274
ParserResult<Stmt> parseStmt();
12731275
ParserStatus parseExprOrStmt(ASTNode &Result);
12741276
ParserResult<Stmt> parseStmtBreak();
@@ -1291,7 +1293,8 @@ class Parser {
12911293
ParserResult<Stmt> parseStmtForEach(SourceLoc ForLoc,
12921294
LabeledStmtInfo LabelInfo);
12931295
ParserResult<Stmt> parseStmtSwitch(LabeledStmtInfo LabelInfo);
1294-
ParserResult<CaseStmt> parseStmtCase();
1296+
ParserStatus parseStmtCases(SmallVectorImpl<ASTNode> &cases, bool IsActive);
1297+
ParserResult<CaseStmt> parseStmtCase(bool IsActive);
12951298

12961299
//===--------------------------------------------------------------------===//
12971300
// Generics Parsing

lib/AST/ASTDumper.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1522,9 +1522,12 @@ class PrintStmt : public StmtVisitor<PrintStmt> {
15221522
void visitSwitchStmt(SwitchStmt *S) {
15231523
printCommon(S, "switch_stmt") << '\n';
15241524
printRec(S->getSubjectExpr());
1525-
for (CaseStmt *C : S->getCases()) {
1525+
for (auto N : S->getRawCases()) {
15261526
OS << '\n';
1527-
printRec(C);
1527+
if (N.is<Stmt*>())
1528+
printRec(N.get<Stmt*>());
1529+
else
1530+
printRec(N.get<Decl*>());
15281531
}
15291532
PrintWithColorRAII(OS, ParenthesisColor) << ')';
15301533
}

lib/AST/ASTPrinter.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3364,8 +3364,11 @@ void PrintAST::visitSwitchStmt(SwitchStmt *stmt) {
33643364
// FIXME: print subject
33653365
Printer << "{";
33663366
Printer.printNewline();
3367-
for (CaseStmt *C : stmt->getCases()) {
3368-
visit(C);
3367+
for (auto N : stmt->getRawCases()) {
3368+
if (N.is<Stmt*>())
3369+
visit(cast<CaseStmt>(N.get<Stmt*>()));
3370+
else
3371+
visit(cast<IfConfigDecl>(N.get<Decl*>()));
33693372
}
33703373
Printer.printNewline();
33713374
indent();

lib/AST/ASTWalker.cpp

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1452,12 +1452,19 @@ Stmt *Traversal::visitSwitchStmt(SwitchStmt *S) {
14521452
else
14531453
return nullptr;
14541454

1455-
for (CaseStmt *aCase : S->getCases()) {
1456-
if (Stmt *aStmt = doIt(aCase)) {
1457-
assert(aCase == aStmt && "switch case remap not supported");
1458-
(void)aStmt;
1459-
} else
1460-
return nullptr;
1455+
for (auto N : S->getRawCases()) {
1456+
if (Stmt *aCase = N.dyn_cast<Stmt*>()) {
1457+
assert(isa<CaseStmt>(aCase));
1458+
if (Stmt *aStmt = doIt(aCase)) {
1459+
assert(aCase == aStmt && "switch case remap not supported");
1460+
(void)aStmt;
1461+
} else
1462+
return nullptr;
1463+
} else {
1464+
assert(isa<IfConfigDecl>(N.get<Decl*>()));
1465+
if (doIt(N.get<Decl*>()))
1466+
return nullptr;
1467+
}
14611468
}
14621469

14631470
return S;

lib/AST/Stmt.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -412,15 +412,22 @@ CaseStmt *CaseStmt::create(ASTContext &C, SourceLoc CaseLoc,
412412
SwitchStmt *SwitchStmt::create(LabeledStmtInfo LabelInfo, SourceLoc SwitchLoc,
413413
Expr *SubjectExpr,
414414
SourceLoc LBraceLoc,
415-
ArrayRef<CaseStmt *> Cases,
415+
ArrayRef<ASTNode> Cases,
416416
SourceLoc RBraceLoc,
417417
ASTContext &C) {
418-
void *p = C.Allocate(totalSizeToAlloc<CaseStmt *>(Cases.size()),
418+
#ifndef NDEBUG
419+
for (auto N : Cases)
420+
assert((N.is<Stmt*>() && isa<CaseStmt>(N.get<Stmt*>())) ||
421+
(N.is<Decl*>() && isa<IfConfigDecl>(N.get<Decl*>())));
422+
#endif
423+
424+
void *p = C.Allocate(totalSizeToAlloc<ASTNode>(Cases.size()),
419425
alignof(SwitchStmt));
420426
SwitchStmt *theSwitch = ::new (p) SwitchStmt(LabelInfo, SwitchLoc,
421427
SubjectExpr, LBraceLoc,
422428
Cases.size(), RBraceLoc);
429+
423430
std::uninitialized_copy(Cases.begin(), Cases.end(),
424-
theSwitch->getTrailingObjects<CaseStmt *>());
431+
theSwitch->getTrailingObjects<ASTNode>());
425432
return theSwitch;
426433
}

lib/Parse/ParseStmt.cpp

Lines changed: 76 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -130,14 +130,25 @@ ParserStatus Parser::parseExprOrStmt(ASTNode &Result) {
130130
return ResultExpr;
131131
}
132132

133-
static bool isTerminatorForBraceItemListKind(const Token &Tok,
134-
BraceItemListKind Kind,
135-
ArrayRef<ASTNode> ParsedDecls) {
133+
bool Parser::isTerminatorForBraceItemListKind(BraceItemListKind Kind,
134+
ArrayRef<ASTNode> ParsedDecls) {
136135
switch (Kind) {
137136
case BraceItemListKind::Brace:
138137
return false;
139138
case BraceItemListKind::Case:
140-
return Tok.is(tok::kw_case) || Tok.is(tok::kw_default);
139+
if (Tok.is(tok::pound_if)) {
140+
// '#if' here could be to guard 'case:' or statements in cases.
141+
// If the next non-directive line starts with 'case' or 'default', it is
142+
// for 'case's.
143+
Parser::BacktrackingScope Backtrack(*this);
144+
do {
145+
consumeToken();
146+
while (!Tok.isAtStartOfLine() && Tok.isNot(tok::eof))
147+
skipSingle();
148+
} while (Tok.isAny(tok::pound_if, tok::pound_elseif, tok::pound_else));
149+
return Tok.isAny(tok::kw_case, tok::kw_default);
150+
}
151+
return Tok.isAny(tok::kw_case, tok::kw_default);
141152
case BraceItemListKind::TopLevelCode:
142153
// When parsing the top level executable code for a module, if we parsed
143154
// some executable code, then we're done. We want to process (name bind,
@@ -247,7 +258,7 @@ ParserStatus Parser::parseBraceItems(SmallVectorImpl<ASTNode> &Entries,
247258
Tok.isNot(tok::kw_sil_witness_table) &&
248259
Tok.isNot(tok::kw_sil_default_witness_table) &&
249260
(isConditionalBlock ||
250-
!isTerminatorForBraceItemListKind(Tok, Kind, Entries))) {
261+
!isTerminatorForBraceItemListKind(Kind, Entries))) {
251262
if (Kind == BraceItemListKind::TopLevelLibrary &&
252263
skipExtraTopLevelRBraces())
253264
continue;
@@ -2152,36 +2163,20 @@ ParserResult<Stmt> Parser::parseStmtSwitch(LabeledStmtInfo LabelInfo) {
21522163
SourceLoc lBraceLoc = consumeToken(tok::l_brace);
21532164
SourceLoc rBraceLoc;
21542165

2155-
// If there are non-case-label statements at the start of the switch body,
2156-
// raise an error and recover by discarding them.
2157-
bool DiagnosedNotCoveredStmt = false;
2158-
while (!Tok.is(tok::kw_case) && !Tok.is(tok::kw_default)
2159-
&& !Tok.is(tok::r_brace) && !Tok.is(tok::eof)) {
2160-
if (!DiagnosedNotCoveredStmt) {
2161-
diagnose(Tok, diag::stmt_in_switch_not_covered_by_case);
2162-
DiagnosedNotCoveredStmt = true;
2163-
}
2164-
skipSingle();
2165-
}
2166-
2167-
SmallVector<CaseStmt*, 8> cases;
2168-
bool parsedDefault = false;
2169-
bool parsedBlockAfterDefault = false;
2170-
while (Tok.is(tok::kw_case) || Tok.is(tok::kw_default)) {
2171-
// We cannot have additional cases after a default clause. Complain on
2172-
// the first offender.
2173-
if (parsedDefault && !parsedBlockAfterDefault) {
2174-
parsedBlockAfterDefault = true;
2175-
diagnose(Tok, diag::case_after_default);
2176-
}
2177-
2178-
ParserResult<CaseStmt> Case = parseStmtCase();
2179-
Status |= Case;
2180-
if (Case.isNonNull()) {
2181-
cases.push_back(Case.get());
2182-
if (Case.get()->isDefault())
2183-
parsedDefault = true;
2166+
SmallVector<ASTNode, 8> cases;
2167+
Status |= parseStmtCases(cases, /*IsActive=*/true);
2168+
2169+
// We cannot have additional cases after a default clause. Complain on
2170+
// the first offender.
2171+
bool hasDefault = false;
2172+
for (auto Element : cases) {
2173+
if (!Element.is<Stmt*>()) continue;
2174+
auto *CS = cast<CaseStmt>(Element.get<Stmt*>());
2175+
if (hasDefault) {
2176+
diagnose(CS->getLoc(), diag::case_after_default);
2177+
break;
21842178
}
2179+
hasDefault |= CS->isDefault();
21852180
}
21862181

21872182
if (parseMatchingToken(tok::r_brace, rBraceLoc,
@@ -2194,6 +2189,51 @@ ParserResult<Stmt> Parser::parseStmtSwitch(LabeledStmtInfo LabelInfo) {
21942189
lBraceLoc, cases, rBraceLoc, Context));
21952190
}
21962191

2192+
ParserStatus
2193+
Parser::parseStmtCases(SmallVectorImpl<ASTNode> &cases, bool IsActive) {
2194+
ParserStatus Status;
2195+
while (Tok.isNot(tok::r_brace, tok::eof,
2196+
tok::pound_endif, tok::pound_elseif, tok::pound_else)) {
2197+
if (Tok.isAny(tok::kw_case, tok::kw_default)) {
2198+
ParserResult<CaseStmt> Case = parseStmtCase(IsActive);
2199+
Status |= Case;
2200+
if (Case.isNonNull())
2201+
cases.emplace_back(Case.get());
2202+
} else if (Tok.is(tok::pound_if)) {
2203+
// '#if' in 'case' position can enclose one or more 'case' or 'default'
2204+
// clauses.
2205+
auto IfConfigResult = parseIfConfig(
2206+
[&](SmallVectorImpl<ASTNode> &Elements, bool IsActive) {
2207+
parseStmtCases(Elements, IsActive);
2208+
});
2209+
Status |= IfConfigResult;
2210+
if (auto ICD = IfConfigResult.getPtrOrNull()) {
2211+
cases.emplace_back(ICD);
2212+
2213+
for (auto &Entry : ICD->getActiveClauseElements()) {
2214+
if (Entry.is<Decl*>() && isa<IfConfigDecl>(Entry.get<Decl*>()))
2215+
// Don't hoist nested '#if'.
2216+
continue;
2217+
2218+
assert(Entry.is<Stmt*>() && isa<CaseStmt>(Entry.get<Stmt*>()));
2219+
cases.push_back(Entry);
2220+
}
2221+
}
2222+
} else {
2223+
// If there are non-case-label statements at the start of the switch body,
2224+
// raise an error and recover by discarding them.
2225+
diagnose(Tok, diag::stmt_in_switch_not_covered_by_case);
2226+
2227+
while (Tok.isNot(tok::r_brace, tok::eof, tok::pound_elseif,
2228+
tok::pound_else, tok::pound_endif) &&
2229+
!isTerminatorForBraceItemListKind(BraceItemListKind::Case, {})) {
2230+
skipSingle();
2231+
}
2232+
}
2233+
}
2234+
return Status;
2235+
}
2236+
21972237
static ParserStatus parseStmtCase(Parser &P, SourceLoc &CaseLoc,
21982238
SmallVectorImpl<CaseLabelItem> &LabelItems,
21992239
SmallVectorImpl<VarDecl *> &BoundDecls,
@@ -2258,9 +2298,9 @@ parseStmtCaseDefault(Parser &P, SourceLoc &CaseLoc,
22582298
return Status;
22592299
}
22602300

2261-
ParserResult<CaseStmt> Parser::parseStmtCase() {
2301+
ParserResult<CaseStmt> Parser::parseStmtCase(bool IsActive) {
22622302
// A case block has its own scope for variables bound out of the pattern.
2263-
Scope S(this, ScopeKind::CaseVars);
2303+
Scope S(this, ScopeKind::CaseVars, !IsActive);
22642304

22652305
ParserStatus Status;
22662306

lib/SILGen/SILGenPattern.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2617,7 +2617,7 @@ void SILGenFunction::emitSwitchStmt(SwitchStmt *S) {
26172617
// We use std::vector because it supports emplace_back; moving a ClauseRow is
26182618
// expensive.
26192619
std::vector<ClauseRow> clauseRows;
2620-
clauseRows.reserve(S->getCases().size());
2620+
clauseRows.reserve(S->getRawCases().size());
26212621
bool hasFallthrough = false;
26222622
for (auto caseBlock : S->getCases()) {
26232623
for (auto &labelItem : caseBlock->getCaseLabelItems()) {

lib/Sema/DerivedConformanceCodingKey.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ deriveBodyCodingKey_enum_stringValue(AbstractFunctionDecl *strValDecl) {
271271
body = BraceStmt::create(C, SourceLoc(), ASTNode(returnStmt),
272272
SourceLoc());
273273
} else {
274-
SmallVector<CaseStmt *, 4> cases;
274+
SmallVector<ASTNode, 4> cases;
275275
for (auto *elt : elements) {
276276
auto *pat = new (C) EnumElementPattern(TypeLoc::withoutLoc(enumType),
277277
SourceLoc(), SourceLoc(),
@@ -336,7 +336,7 @@ deriveBodyCodingKey_init_stringValue(AbstractFunctionDecl *initDecl) {
336336
}
337337

338338
auto *selfRef = createSelfDeclRef(initDecl);
339-
SmallVector<CaseStmt *, 4> cases;
339+
SmallVector<ASTNode, 4> cases;
340340
for (auto *elt : elements) {
341341
auto *litExpr = new (C) StringLiteralExpr(elt->getNameStr(), SourceRange(),
342342
/*Implicit=*/true);

lib/Sema/DerivedConformanceEquatableHashable.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ static DeclRefExpr *convertEnumToIndex(SmallVectorImpl<ASTNode> &stmts,
8888
indexPat, nullptr, funcDecl);
8989

9090
unsigned index = 0;
91-
SmallVector<CaseStmt*, 4> cases;
91+
SmallVector<ASTNode, 4> cases;
9292
for (auto elt : enumDecl->getAllElements()) {
9393
// generate: case .<Case>:
9494
auto pat = new (C) EnumElementPattern(TypeLoc::withoutLoc(enumType),

lib/Sema/DerivedConformanceRawRepresentable.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ static void deriveBodyRawRepresentable_raw(AbstractFunctionDecl *toRawDecl) {
9393

9494
Type enumType = parentDC->getDeclaredTypeInContext();
9595

96-
SmallVector<CaseStmt*, 4> cases;
96+
SmallVector<ASTNode, 4> cases;
9797
for (auto elt : enumDecl->getAllElements()) {
9898
auto pat = new (C) EnumElementPattern(TypeLoc::withoutLoc(enumType),
9999
SourceLoc(), SourceLoc(),
@@ -198,7 +198,7 @@ deriveBodyRawRepresentable_init(AbstractFunctionDecl *initDecl) {
198198

199199
auto selfDecl = cast<ConstructorDecl>(initDecl)->getImplicitSelfDecl();
200200

201-
SmallVector<CaseStmt*, 4> cases;
201+
SmallVector<ASTNode, 4> cases;
202202
for (auto elt : enumDecl->getAllElements()) {
203203
auto litExpr = cloneRawLiteralExpr(C, elt->getRawValueExpr());
204204
auto litPat = new (C) ExprPattern(litExpr, /*isResolved*/ true,

lib/Sema/TypeCheckStmt.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -863,11 +863,12 @@ class StmtChecker : public StmtVisitor<StmtChecker, Stmt*> {
863863
AddSwitchNest switchNest(*this);
864864
AddLabeledStmt labelNest(*this, S);
865865

866-
for (unsigned i = 0, e = S->getCases().size(); i < e; ++i) {
867-
auto *caseBlock = S->getCases()[i];
866+
auto cases = S->getCases();
867+
for (auto i = cases.begin(), e = cases.end(); i != e; ++i) {
868+
auto *caseBlock = *i;
868869
// Fallthrough transfers control to the next case block. In the
869870
// final case block, it is invalid.
870-
FallthroughDest = i+1 == e ? nullptr : S->getCases()[i+1];
871+
FallthroughDest = std::next(i) == e ? nullptr : *std::next(i);
871872

872873
for (auto &labelItem : caseBlock->getMutableCaseLabelItems()) {
873874
// Resolve the pattern in the label.

lib/Sema/TypeCheckSwitchStmt.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -931,8 +931,7 @@ namespace {
931931
bool sawDowngradablePattern = false;
932932
bool sawRedundantPattern = false;
933933
SmallVector<Space, 4> spaces;
934-
for (unsigned i = 0, e = Switch->getCases().size(); i < e; ++i) {
935-
auto *caseBlock = Switch->getCases()[i];
934+
for (auto *caseBlock : Switch->getCases()) {
936935
for (auto &caseItem : caseBlock->getCaseLabelItems()) {
937936
// 'where'-clauses on cases mean the case does not contribute to
938937
// the exhaustiveness of the pattern.

0 commit comments

Comments
 (0)