Skip to content

[ast]/[silgen] If a case stmt is a fallthrough source, tail allocate a pointer in the case stmt to the fallthrough case. #23284

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 92 additions & 60 deletions include/swift/AST/Stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,18 @@
#include "llvm/Support/TrailingObjects.h"

namespace swift {
class AnyPattern;
class ASTContext;
class ASTWalker;
class Decl;
class Expr;
class FuncDecl;
class Pattern;
class PatternBindingDecl;
class VarDecl;


class AnyPattern;
class ASTContext;
class ASTWalker;
class Decl;
class Expr;
class FuncDecl;
class Pattern;
class PatternBindingDecl;
class VarDecl;
class CaseStmt;

enum class StmtKind {
#define STMT(ID, PARENT) ID,
#define LAST_STMT(ID) Last_Stmt = ID,
Expand Down Expand Up @@ -920,6 +922,45 @@ class CaseLabelItem {
}
};

/// FallthroughStmt - The keyword "fallthrough".
class FallthroughStmt : public Stmt {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are no changes here to FallthroughStmt. I just ran into issues with my uses for it in CaseStmt and then discovered that I couldn't just use a forward declaration of it = /.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

*So I moved it above CaseStmt.

SourceLoc Loc;
CaseStmt *FallthroughSource;
CaseStmt *FallthroughDest;

public:
FallthroughStmt(SourceLoc Loc, Optional<bool> implicit = None)
: Stmt(StmtKind::Fallthrough, getDefaultImplicitFlag(implicit, Loc)),
Loc(Loc), FallthroughSource(nullptr), FallthroughDest(nullptr) {}

SourceLoc getLoc() const { return Loc; }

SourceRange getSourceRange() const { return Loc; }

/// Get the CaseStmt block from which the fallthrough transfers control.
/// Set during Sema. (May stay null if fallthrough is invalid.)
CaseStmt *getFallthroughSource() const { return FallthroughSource; }
void setFallthroughSource(CaseStmt *C) {
assert(!FallthroughSource && "fallthrough source already set?!");
FallthroughSource = C;
}

/// Get the CaseStmt block to which the fallthrough transfers control.
/// Set during Sema.
CaseStmt *getFallthroughDest() const {
assert(FallthroughDest && "fallthrough dest is not set until Sema");
return FallthroughDest;
}
void setFallthroughDest(CaseStmt *C) {
assert(!FallthroughDest && "fallthrough dest already set?!");
FallthroughDest = C;
}

static bool classof(const Stmt *S) {
return S->getKind() == StmtKind::Fallthrough;
}
};

/// A 'case' or 'default' block of a switch statement. Only valid as the
/// substatement of a SwitchStmt. A case block begins either with one or more
/// CaseLabelItems or a single 'default' label.
Expand All @@ -933,8 +974,10 @@ class CaseLabelItem {
/// default:
/// \endcode
///
class CaseStmt final : public Stmt,
private llvm::TrailingObjects<CaseStmt, CaseLabelItem> {
class CaseStmt final
: public Stmt,
private llvm::TrailingObjects<CaseStmt, FallthroughStmt *,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I decided to store the entire FallthroughStmt * rather than a pointer to the CaseStmt* returned from FallthroughStmt::getFallthroughDest(). This seems more flexible. That being said, I did not want to add it to the API surface, so I did not expose it directly.

CaseLabelItem> {
friend TrailingObjects;

SourceLoc UnknownAttrLoc;
Expand All @@ -943,24 +986,47 @@ class CaseStmt final : public Stmt,

llvm::PointerIntPair<Stmt *, 1, bool> BodyAndHasBoundDecls;

/// Set to true if we have a fallthrough.
///
/// TODO: Once we have CaseBodyVarDecls, use the bit in BodyAndHasBoundDecls
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To be clear here: HasBoundDecls is true iff we have CaseBodyVarDecls to my commit adding those will fix that. I just did not want to create a field called BodyAndHasBoundDeclsAndHasFallthrough...

/// for this instead. This is separate now for staging reasons.
bool hasFallthrough;

CaseStmt(SourceLoc CaseLoc, ArrayRef<CaseLabelItem> CaseLabelItems,
bool HasBoundDecls, SourceLoc UnknownAttrLoc, SourceLoc ColonLoc,
Stmt *Body, Optional<bool> Implicit);
Stmt *Body, Optional<bool> Implicit,
NullablePtr<FallthroughStmt> fallthroughStmt);

public:
static CaseStmt *create(ASTContext &C, SourceLoc CaseLoc,
ArrayRef<CaseLabelItem> CaseLabelItems,
bool HasBoundDecls, SourceLoc UnknownAttrLoc,
SourceLoc ColonLoc, Stmt *Body,
Optional<bool> Implicit = None);
static CaseStmt *
create(ASTContext &C, SourceLoc CaseLoc,
ArrayRef<CaseLabelItem> CaseLabelItems, bool HasBoundDecls,
SourceLoc UnknownAttrLoc, SourceLoc ColonLoc, Stmt *Body,
Optional<bool> Implicit = None,
NullablePtr<FallthroughStmt> fallthroughStmt = nullptr);

ArrayRef<CaseLabelItem> getCaseLabelItems() const {
return {getTrailingObjects<CaseLabelItem>(), Bits.CaseStmt.NumPatterns};
}

MutableArrayRef<CaseLabelItem> getMutableCaseLabelItems() {
return {getTrailingObjects<CaseLabelItem>(), Bits.CaseStmt.NumPatterns};
}

unsigned getNumCaseLabelItems() const { return Bits.CaseStmt.NumPatterns; }

NullablePtr<CaseStmt> getFallthroughDest() const {
return const_cast<CaseStmt &>(*this).getFallthroughDest();
}

NullablePtr<CaseStmt> getFallthroughDest() {
if (!hasFallthrough)
return nullptr;
return (*getTrailingObjects<FallthroughStmt *>())->getFallthroughDest();
}

bool hasFallthroughDest() const { return hasFallthrough; }

Stmt *getBody() const { return BodyAndHasBoundDecls.getPointer(); }
void setBody(Stmt *body) { BodyAndHasBoundDecls.setPointer(body); }

Expand Down Expand Up @@ -991,6 +1057,14 @@ class CaseStmt final : public Stmt,
}

static bool classof(const Stmt *S) { return S->getKind() == StmtKind::Case; }

size_t numTrailingObjects(OverloadToken<CaseLabelItem>) const {
return getNumCaseLabelItems();
}

size_t numTrailingObjects(OverloadToken<FallthroughStmt *>) const {
return hasFallthrough ? 1 : 0;
}
};

/// Switch statement.
Expand Down Expand Up @@ -1135,48 +1209,6 @@ class ContinueStmt : public Stmt {
}
};

/// FallthroughStmt - The keyword "fallthrough".
class FallthroughStmt : public Stmt {
SourceLoc Loc;
CaseStmt *FallthroughSource;
CaseStmt *FallthroughDest;

public:
FallthroughStmt(SourceLoc Loc, Optional<bool> implicit = None)
: Stmt(StmtKind::Fallthrough, getDefaultImplicitFlag(implicit, Loc)),
Loc(Loc), FallthroughSource(nullptr), FallthroughDest(nullptr)
{}

SourceLoc getLoc() const { return Loc; }

SourceRange getSourceRange() const { return Loc; }

/// Get the CaseStmt block from which the fallthrough transfers control.
/// Set during Sema. (May stay null if fallthrough is invalid.)
CaseStmt *getFallthroughSource() const {
return FallthroughSource;
}
void setFallthroughSource(CaseStmt *C) {
assert(!FallthroughSource && "fallthrough source already set?!");
FallthroughSource = C;
}

/// Get the CaseStmt block to which the fallthrough transfers control.
/// Set during Sema.
CaseStmt *getFallthroughDest() const {
assert(FallthroughDest && "fallthrough dest is not set until Sema");
return FallthroughDest;
}
void setFallthroughDest(CaseStmt *C) {
assert(!FallthroughDest && "fallthrough dest already set?!");
FallthroughDest = C;
}

static bool classof(const Stmt *S) {
return S->getKind() == StmtKind::Fallthrough;
}
};

/// FailStmt - A statement that indicates a failable, which is currently
/// spelled as "return nil" and can only be used within failable initializers.
class FailStmt : public Stmt {
Expand Down
53 changes: 32 additions & 21 deletions lib/AST/Stmt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -385,33 +385,44 @@ SourceLoc CaseLabelItem::getEndLoc() const {
return CasePattern->getEndLoc();
}

CaseStmt::CaseStmt(SourceLoc CaseLoc, ArrayRef<CaseLabelItem> CaseLabelItems,
bool HasBoundDecls, SourceLoc UnknownAttrLoc,
SourceLoc ColonLoc, Stmt *Body, Optional<bool> Implicit)
: Stmt(StmtKind::Case, getDefaultImplicitFlag(Implicit, CaseLoc)),
UnknownAttrLoc(UnknownAttrLoc), CaseLoc(CaseLoc), ColonLoc(ColonLoc),
BodyAndHasBoundDecls(Body, HasBoundDecls) {
Bits.CaseStmt.NumPatterns = CaseLabelItems.size();
CaseStmt::CaseStmt(SourceLoc caseLoc, ArrayRef<CaseLabelItem> caseLabelItems,
bool hasBoundDecls, SourceLoc unknownAttrLoc,
SourceLoc colonLoc, Stmt *body, Optional<bool> implicit,
NullablePtr<FallthroughStmt> fallthroughStmt)
: Stmt(StmtKind::Case, getDefaultImplicitFlag(implicit, caseLoc)),
UnknownAttrLoc(unknownAttrLoc), CaseLoc(caseLoc), ColonLoc(colonLoc),
BodyAndHasBoundDecls(body, hasBoundDecls),
hasFallthrough(fallthroughStmt.isNonNull()) {
Bits.CaseStmt.NumPatterns = caseLabelItems.size();
assert(Bits.CaseStmt.NumPatterns > 0 &&
"case block must have at least one pattern");
MutableArrayRef<CaseLabelItem> Items{ getTrailingObjects<CaseLabelItem>(),
Bits.CaseStmt.NumPatterns };

for (unsigned i = 0; i < Bits.CaseStmt.NumPatterns; ++i) {
new (&Items[i]) CaseLabelItem(CaseLabelItems[i]);
Items[i].getPattern()->markOwnedByStatement(this);
if (hasFallthrough) {
*getTrailingObjects<FallthroughStmt *>() = fallthroughStmt.get();
}

MutableArrayRef<CaseLabelItem> items{getTrailingObjects<CaseLabelItem>(),
Bits.CaseStmt.NumPatterns};

for (unsigned i : range(Bits.CaseStmt.NumPatterns)) {
new (&items[i]) CaseLabelItem(caseLabelItems[i]);
items[i].getPattern()->markOwnedByStatement(this);
}
}

CaseStmt *CaseStmt::create(ASTContext &C, SourceLoc CaseLoc,
ArrayRef<CaseLabelItem> CaseLabelItems,
bool HasBoundDecls, SourceLoc UnknownAttrLoc,
SourceLoc ColonLoc, Stmt *Body,
Optional<bool> Implicit) {
void *Mem = C.Allocate(totalSizeToAlloc<CaseLabelItem>(CaseLabelItems.size()),
alignof(CaseStmt));
return ::new (Mem) CaseStmt(CaseLoc, CaseLabelItems, HasBoundDecls,
UnknownAttrLoc, ColonLoc, Body, Implicit);
CaseStmt *CaseStmt::create(ASTContext &ctx, SourceLoc caseLoc,
ArrayRef<CaseLabelItem> caseLabelItems,
bool hasBoundDecls, SourceLoc unknownAttrLoc,
SourceLoc colonLoc, Stmt *body,
Optional<bool> implicit,
NullablePtr<FallthroughStmt> fallthroughStmt) {
void *mem =
ctx.Allocate(totalSizeToAlloc<FallthroughStmt *, CaseLabelItem>(
fallthroughStmt.isNonNull(), caseLabelItems.size()),
alignof(CaseStmt));
return ::new (mem)
CaseStmt(caseLoc, caseLabelItems, hasBoundDecls, unknownAttrLoc, colonLoc,
body, implicit, fallthroughStmt);
}

SwitchStmt *SwitchStmt::create(LabeledStmtInfo LabelInfo, SourceLoc SwitchLoc,
Expand Down
49 changes: 44 additions & 5 deletions lib/Parse/ParseStmt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,14 @@
//
//===----------------------------------------------------------------------===//

#include "swift/Parse/Parser.h"
#include "swift/AST/ASTWalker.h"
#include "swift/AST/Attr.h"
#include "swift/AST/Decl.h"
#include "swift/Basic/Defer.h"
#include "swift/Basic/Version.h"
#include "swift/Parse/Lexer.h"
#include "swift/Parse/CodeCompletionCallbacks.h"
#include "swift/Parse/Lexer.h"
#include "swift/Parse/Parser.h"
#include "swift/Parse/SyntaxParsingContext.h"
#include "swift/Subsystems.h"
#include "swift/Syntax/TokenSyntax.h"
Expand Down Expand Up @@ -2350,6 +2351,43 @@ parseStmtCaseDefault(Parser &P, SourceLoc &CaseLoc,
return Status;
}

namespace {

struct FallthroughFinder : ASTWalker {
FallthroughStmt *result;

FallthroughFinder() : result(nullptr) {}

// We walk through statements. If we find a fallthrough, then we got what
// we came for.
std::pair<bool, Stmt *> walkToStmtPre(Stmt *s) override {
if (auto *f = dyn_cast<FallthroughStmt>(s)) {
result = f;
}

return {true, s};
}

// Expressions, patterns and decls cannot contain fallthrough statements, so
// there is no reason to walk into them.
std::pair<bool, Expr *> walkToExprPre(Expr *e) override { return {false, e}; }
std::pair<bool, Pattern *> walkToPatternPre(Pattern *p) override {
return {false, p};
}

bool walkToDeclPre(Decl *d) override { return false; }
bool walkToTypeLocPre(TypeLoc &tl) override { return false; }
bool walkToTypeReprPre(TypeRepr *t) override { return false; }

static FallthroughStmt *findFallthrough(Stmt *s) {
FallthroughFinder finder;
s->walk(finder);
return finder.result;
}
};

} // end anonymous namespace

ParserResult<CaseStmt> Parser::parseStmtCase(bool IsActive) {
SyntaxParsingContext CaseContext(SyntaxContext, SyntaxKind::SwitchCase);
// A case block has its own scope for variables bound out of the pattern.
Expand Down Expand Up @@ -2424,9 +2462,10 @@ ParserResult<CaseStmt> Parser::parseStmtCase(bool IsActive) {
}

return makeParserResult(
Status, CaseStmt::create(Context, CaseLoc, CaseLabelItems,
!BoundDecls.empty(), UnknownAttrLoc, ColonLoc,
Body));
Status,
CaseStmt::create(Context, CaseLoc, CaseLabelItems, !BoundDecls.empty(),
UnknownAttrLoc, ColonLoc, Body, None,
FallthroughFinder::findFallthrough(Body)));
}

/// stmt-pound-assert:
Expand Down
Loading