Skip to content

Commit 1e0414d

Browse files
authored
Merge pull request #23284 from gottesmm/pr-0b4de614fe103f8e5a161be5219ba2cb7148aa25
[ast]/[silgen] If a case stmt is a fallthrough source, tail allocate a pointer in the case stmt to the fallthrough case.
2 parents dee8d7b + 7b27b4b commit 1e0414d

File tree

4 files changed

+170
-126
lines changed

4 files changed

+170
-126
lines changed

include/swift/AST/Stmt.h

Lines changed: 92 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -26,16 +26,18 @@
2626
#include "llvm/Support/TrailingObjects.h"
2727

2828
namespace swift {
29-
class AnyPattern;
30-
class ASTContext;
31-
class ASTWalker;
32-
class Decl;
33-
class Expr;
34-
class FuncDecl;
35-
class Pattern;
36-
class PatternBindingDecl;
37-
class VarDecl;
38-
29+
30+
class AnyPattern;
31+
class ASTContext;
32+
class ASTWalker;
33+
class Decl;
34+
class Expr;
35+
class FuncDecl;
36+
class Pattern;
37+
class PatternBindingDecl;
38+
class VarDecl;
39+
class CaseStmt;
40+
3941
enum class StmtKind {
4042
#define STMT(ID, PARENT) ID,
4143
#define LAST_STMT(ID) Last_Stmt = ID,
@@ -920,6 +922,45 @@ class CaseLabelItem {
920922
}
921923
};
922924

925+
/// FallthroughStmt - The keyword "fallthrough".
926+
class FallthroughStmt : public Stmt {
927+
SourceLoc Loc;
928+
CaseStmt *FallthroughSource;
929+
CaseStmt *FallthroughDest;
930+
931+
public:
932+
FallthroughStmt(SourceLoc Loc, Optional<bool> implicit = None)
933+
: Stmt(StmtKind::Fallthrough, getDefaultImplicitFlag(implicit, Loc)),
934+
Loc(Loc), FallthroughSource(nullptr), FallthroughDest(nullptr) {}
935+
936+
SourceLoc getLoc() const { return Loc; }
937+
938+
SourceRange getSourceRange() const { return Loc; }
939+
940+
/// Get the CaseStmt block from which the fallthrough transfers control.
941+
/// Set during Sema. (May stay null if fallthrough is invalid.)
942+
CaseStmt *getFallthroughSource() const { return FallthroughSource; }
943+
void setFallthroughSource(CaseStmt *C) {
944+
assert(!FallthroughSource && "fallthrough source already set?!");
945+
FallthroughSource = C;
946+
}
947+
948+
/// Get the CaseStmt block to which the fallthrough transfers control.
949+
/// Set during Sema.
950+
CaseStmt *getFallthroughDest() const {
951+
assert(FallthroughDest && "fallthrough dest is not set until Sema");
952+
return FallthroughDest;
953+
}
954+
void setFallthroughDest(CaseStmt *C) {
955+
assert(!FallthroughDest && "fallthrough dest already set?!");
956+
FallthroughDest = C;
957+
}
958+
959+
static bool classof(const Stmt *S) {
960+
return S->getKind() == StmtKind::Fallthrough;
961+
}
962+
};
963+
923964
/// A 'case' or 'default' block of a switch statement. Only valid as the
924965
/// substatement of a SwitchStmt. A case block begins either with one or more
925966
/// CaseLabelItems or a single 'default' label.
@@ -933,8 +974,10 @@ class CaseLabelItem {
933974
/// default:
934975
/// \endcode
935976
///
936-
class CaseStmt final : public Stmt,
937-
private llvm::TrailingObjects<CaseStmt, CaseLabelItem> {
977+
class CaseStmt final
978+
: public Stmt,
979+
private llvm::TrailingObjects<CaseStmt, FallthroughStmt *,
980+
CaseLabelItem> {
938981
friend TrailingObjects;
939982

940983
SourceLoc UnknownAttrLoc;
@@ -943,24 +986,47 @@ class CaseStmt final : public Stmt,
943986

944987
llvm::PointerIntPair<Stmt *, 1, bool> BodyAndHasBoundDecls;
945988

989+
/// Set to true if we have a fallthrough.
990+
///
991+
/// TODO: Once we have CaseBodyVarDecls, use the bit in BodyAndHasBoundDecls
992+
/// for this instead. This is separate now for staging reasons.
993+
bool hasFallthrough;
994+
946995
CaseStmt(SourceLoc CaseLoc, ArrayRef<CaseLabelItem> CaseLabelItems,
947996
bool HasBoundDecls, SourceLoc UnknownAttrLoc, SourceLoc ColonLoc,
948-
Stmt *Body, Optional<bool> Implicit);
997+
Stmt *Body, Optional<bool> Implicit,
998+
NullablePtr<FallthroughStmt> fallthroughStmt);
949999

9501000
public:
951-
static CaseStmt *create(ASTContext &C, SourceLoc CaseLoc,
952-
ArrayRef<CaseLabelItem> CaseLabelItems,
953-
bool HasBoundDecls, SourceLoc UnknownAttrLoc,
954-
SourceLoc ColonLoc, Stmt *Body,
955-
Optional<bool> Implicit = None);
1001+
static CaseStmt *
1002+
create(ASTContext &C, SourceLoc CaseLoc,
1003+
ArrayRef<CaseLabelItem> CaseLabelItems, bool HasBoundDecls,
1004+
SourceLoc UnknownAttrLoc, SourceLoc ColonLoc, Stmt *Body,
1005+
Optional<bool> Implicit = None,
1006+
NullablePtr<FallthroughStmt> fallthroughStmt = nullptr);
9561007

9571008
ArrayRef<CaseLabelItem> getCaseLabelItems() const {
9581009
return {getTrailingObjects<CaseLabelItem>(), Bits.CaseStmt.NumPatterns};
9591010
}
1011+
9601012
MutableArrayRef<CaseLabelItem> getMutableCaseLabelItems() {
9611013
return {getTrailingObjects<CaseLabelItem>(), Bits.CaseStmt.NumPatterns};
9621014
}
9631015

1016+
unsigned getNumCaseLabelItems() const { return Bits.CaseStmt.NumPatterns; }
1017+
1018+
NullablePtr<CaseStmt> getFallthroughDest() const {
1019+
return const_cast<CaseStmt &>(*this).getFallthroughDest();
1020+
}
1021+
1022+
NullablePtr<CaseStmt> getFallthroughDest() {
1023+
if (!hasFallthrough)
1024+
return nullptr;
1025+
return (*getTrailingObjects<FallthroughStmt *>())->getFallthroughDest();
1026+
}
1027+
1028+
bool hasFallthroughDest() const { return hasFallthrough; }
1029+
9641030
Stmt *getBody() const { return BodyAndHasBoundDecls.getPointer(); }
9651031
void setBody(Stmt *body) { BodyAndHasBoundDecls.setPointer(body); }
9661032

@@ -991,6 +1057,14 @@ class CaseStmt final : public Stmt,
9911057
}
9921058

9931059
static bool classof(const Stmt *S) { return S->getKind() == StmtKind::Case; }
1060+
1061+
size_t numTrailingObjects(OverloadToken<CaseLabelItem>) const {
1062+
return getNumCaseLabelItems();
1063+
}
1064+
1065+
size_t numTrailingObjects(OverloadToken<FallthroughStmt *>) const {
1066+
return hasFallthrough ? 1 : 0;
1067+
}
9941068
};
9951069

9961070
/// Switch statement.
@@ -1135,48 +1209,6 @@ class ContinueStmt : public Stmt {
11351209
}
11361210
};
11371211

1138-
/// FallthroughStmt - The keyword "fallthrough".
1139-
class FallthroughStmt : public Stmt {
1140-
SourceLoc Loc;
1141-
CaseStmt *FallthroughSource;
1142-
CaseStmt *FallthroughDest;
1143-
1144-
public:
1145-
FallthroughStmt(SourceLoc Loc, Optional<bool> implicit = None)
1146-
: Stmt(StmtKind::Fallthrough, getDefaultImplicitFlag(implicit, Loc)),
1147-
Loc(Loc), FallthroughSource(nullptr), FallthroughDest(nullptr)
1148-
{}
1149-
1150-
SourceLoc getLoc() const { return Loc; }
1151-
1152-
SourceRange getSourceRange() const { return Loc; }
1153-
1154-
/// Get the CaseStmt block from which the fallthrough transfers control.
1155-
/// Set during Sema. (May stay null if fallthrough is invalid.)
1156-
CaseStmt *getFallthroughSource() const {
1157-
return FallthroughSource;
1158-
}
1159-
void setFallthroughSource(CaseStmt *C) {
1160-
assert(!FallthroughSource && "fallthrough source already set?!");
1161-
FallthroughSource = C;
1162-
}
1163-
1164-
/// Get the CaseStmt block to which the fallthrough transfers control.
1165-
/// Set during Sema.
1166-
CaseStmt *getFallthroughDest() const {
1167-
assert(FallthroughDest && "fallthrough dest is not set until Sema");
1168-
return FallthroughDest;
1169-
}
1170-
void setFallthroughDest(CaseStmt *C) {
1171-
assert(!FallthroughDest && "fallthrough dest already set?!");
1172-
FallthroughDest = C;
1173-
}
1174-
1175-
static bool classof(const Stmt *S) {
1176-
return S->getKind() == StmtKind::Fallthrough;
1177-
}
1178-
};
1179-
11801212
/// FailStmt - A statement that indicates a failable, which is currently
11811213
/// spelled as "return nil" and can only be used within failable initializers.
11821214
class FailStmt : public Stmt {

lib/AST/Stmt.cpp

Lines changed: 32 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -385,33 +385,44 @@ SourceLoc CaseLabelItem::getEndLoc() const {
385385
return CasePattern->getEndLoc();
386386
}
387387

388-
CaseStmt::CaseStmt(SourceLoc CaseLoc, ArrayRef<CaseLabelItem> CaseLabelItems,
389-
bool HasBoundDecls, SourceLoc UnknownAttrLoc,
390-
SourceLoc ColonLoc, Stmt *Body, Optional<bool> Implicit)
391-
: Stmt(StmtKind::Case, getDefaultImplicitFlag(Implicit, CaseLoc)),
392-
UnknownAttrLoc(UnknownAttrLoc), CaseLoc(CaseLoc), ColonLoc(ColonLoc),
393-
BodyAndHasBoundDecls(Body, HasBoundDecls) {
394-
Bits.CaseStmt.NumPatterns = CaseLabelItems.size();
388+
CaseStmt::CaseStmt(SourceLoc caseLoc, ArrayRef<CaseLabelItem> caseLabelItems,
389+
bool hasBoundDecls, SourceLoc unknownAttrLoc,
390+
SourceLoc colonLoc, Stmt *body, Optional<bool> implicit,
391+
NullablePtr<FallthroughStmt> fallthroughStmt)
392+
: Stmt(StmtKind::Case, getDefaultImplicitFlag(implicit, caseLoc)),
393+
UnknownAttrLoc(unknownAttrLoc), CaseLoc(caseLoc), ColonLoc(colonLoc),
394+
BodyAndHasBoundDecls(body, hasBoundDecls),
395+
hasFallthrough(fallthroughStmt.isNonNull()) {
396+
Bits.CaseStmt.NumPatterns = caseLabelItems.size();
395397
assert(Bits.CaseStmt.NumPatterns > 0 &&
396398
"case block must have at least one pattern");
397-
MutableArrayRef<CaseLabelItem> Items{ getTrailingObjects<CaseLabelItem>(),
398-
Bits.CaseStmt.NumPatterns };
399399

400-
for (unsigned i = 0; i < Bits.CaseStmt.NumPatterns; ++i) {
401-
new (&Items[i]) CaseLabelItem(CaseLabelItems[i]);
402-
Items[i].getPattern()->markOwnedByStatement(this);
400+
if (hasFallthrough) {
401+
*getTrailingObjects<FallthroughStmt *>() = fallthroughStmt.get();
402+
}
403+
404+
MutableArrayRef<CaseLabelItem> items{getTrailingObjects<CaseLabelItem>(),
405+
Bits.CaseStmt.NumPatterns};
406+
407+
for (unsigned i : range(Bits.CaseStmt.NumPatterns)) {
408+
new (&items[i]) CaseLabelItem(caseLabelItems[i]);
409+
items[i].getPattern()->markOwnedByStatement(this);
403410
}
404411
}
405412

406-
CaseStmt *CaseStmt::create(ASTContext &C, SourceLoc CaseLoc,
407-
ArrayRef<CaseLabelItem> CaseLabelItems,
408-
bool HasBoundDecls, SourceLoc UnknownAttrLoc,
409-
SourceLoc ColonLoc, Stmt *Body,
410-
Optional<bool> Implicit) {
411-
void *Mem = C.Allocate(totalSizeToAlloc<CaseLabelItem>(CaseLabelItems.size()),
412-
alignof(CaseStmt));
413-
return ::new (Mem) CaseStmt(CaseLoc, CaseLabelItems, HasBoundDecls,
414-
UnknownAttrLoc, ColonLoc, Body, Implicit);
413+
CaseStmt *CaseStmt::create(ASTContext &ctx, SourceLoc caseLoc,
414+
ArrayRef<CaseLabelItem> caseLabelItems,
415+
bool hasBoundDecls, SourceLoc unknownAttrLoc,
416+
SourceLoc colonLoc, Stmt *body,
417+
Optional<bool> implicit,
418+
NullablePtr<FallthroughStmt> fallthroughStmt) {
419+
void *mem =
420+
ctx.Allocate(totalSizeToAlloc<FallthroughStmt *, CaseLabelItem>(
421+
fallthroughStmt.isNonNull(), caseLabelItems.size()),
422+
alignof(CaseStmt));
423+
return ::new (mem)
424+
CaseStmt(caseLoc, caseLabelItems, hasBoundDecls, unknownAttrLoc, colonLoc,
425+
body, implicit, fallthroughStmt);
415426
}
416427

417428
SwitchStmt *SwitchStmt::create(LabeledStmtInfo LabelInfo, SourceLoc SwitchLoc,

lib/Parse/ParseStmt.cpp

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,14 @@
1414
//
1515
//===----------------------------------------------------------------------===//
1616

17-
#include "swift/Parse/Parser.h"
17+
#include "swift/AST/ASTWalker.h"
1818
#include "swift/AST/Attr.h"
1919
#include "swift/AST/Decl.h"
2020
#include "swift/Basic/Defer.h"
2121
#include "swift/Basic/Version.h"
22-
#include "swift/Parse/Lexer.h"
2322
#include "swift/Parse/CodeCompletionCallbacks.h"
23+
#include "swift/Parse/Lexer.h"
24+
#include "swift/Parse/Parser.h"
2425
#include "swift/Parse/SyntaxParsingContext.h"
2526
#include "swift/Subsystems.h"
2627
#include "swift/Syntax/TokenSyntax.h"
@@ -2350,6 +2351,43 @@ parseStmtCaseDefault(Parser &P, SourceLoc &CaseLoc,
23502351
return Status;
23512352
}
23522353

2354+
namespace {
2355+
2356+
struct FallthroughFinder : ASTWalker {
2357+
FallthroughStmt *result;
2358+
2359+
FallthroughFinder() : result(nullptr) {}
2360+
2361+
// We walk through statements. If we find a fallthrough, then we got what
2362+
// we came for.
2363+
std::pair<bool, Stmt *> walkToStmtPre(Stmt *s) override {
2364+
if (auto *f = dyn_cast<FallthroughStmt>(s)) {
2365+
result = f;
2366+
}
2367+
2368+
return {true, s};
2369+
}
2370+
2371+
// Expressions, patterns and decls cannot contain fallthrough statements, so
2372+
// there is no reason to walk into them.
2373+
std::pair<bool, Expr *> walkToExprPre(Expr *e) override { return {false, e}; }
2374+
std::pair<bool, Pattern *> walkToPatternPre(Pattern *p) override {
2375+
return {false, p};
2376+
}
2377+
2378+
bool walkToDeclPre(Decl *d) override { return false; }
2379+
bool walkToTypeLocPre(TypeLoc &tl) override { return false; }
2380+
bool walkToTypeReprPre(TypeRepr *t) override { return false; }
2381+
2382+
static FallthroughStmt *findFallthrough(Stmt *s) {
2383+
FallthroughFinder finder;
2384+
s->walk(finder);
2385+
return finder.result;
2386+
}
2387+
};
2388+
2389+
} // end anonymous namespace
2390+
23532391
ParserResult<CaseStmt> Parser::parseStmtCase(bool IsActive) {
23542392
SyntaxParsingContext CaseContext(SyntaxContext, SyntaxKind::SwitchCase);
23552393
// A case block has its own scope for variables bound out of the pattern.
@@ -2424,9 +2462,10 @@ ParserResult<CaseStmt> Parser::parseStmtCase(bool IsActive) {
24242462
}
24252463

24262464
return makeParserResult(
2427-
Status, CaseStmt::create(Context, CaseLoc, CaseLabelItems,
2428-
!BoundDecls.empty(), UnknownAttrLoc, ColonLoc,
2429-
Body));
2465+
Status,
2466+
CaseStmt::create(Context, CaseLoc, CaseLabelItems, !BoundDecls.empty(),
2467+
UnknownAttrLoc, ColonLoc, Body, None,
2468+
FallthroughFinder::findFallthrough(Body)));
24302469
}
24312470

24322471
/// stmt-pound-assert:

0 commit comments

Comments
 (0)