Skip to content

[Sema] Improve handling of fallthrough in if/switch expressions #75891

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions include/swift/AST/ASTBridging.h
Original file line number Diff line number Diff line change
Expand Up @@ -1491,10 +1491,10 @@ BridgedDoCatchStmt BridgedDoCatchStmt_createParsed(
BridgedNullableTypeRepr cThrownType, BridgedStmt cBody,
BridgedArrayRef cCatches);

SWIFT_NAME("BridgedFallthroughStmt.createParsed(_:loc:)")
SWIFT_NAME("BridgedFallthroughStmt.createParsed(loc:declContext:)")
BridgedFallthroughStmt
BridgedFallthroughStmt_createParsed(BridgedASTContext cContext,
BridgedSourceLoc cLoc);
BridgedFallthroughStmt_createParsed(BridgedSourceLoc cLoc,
BridgedDeclContext cDC);

SWIFT_NAME("BridgedForEachStmt.createParsed(_:labelInfo:forLoc:tryLoc:awaitLoc:"
"pattern:inLoc:sequence:whereLoc:whereExpr:body:)")
Expand Down
7 changes: 4 additions & 3 deletions include/swift/AST/DiagnosticsSema.def
Original file line number Diff line number Diff line change
Expand Up @@ -1328,10 +1328,11 @@ ERROR(single_value_stmt_branch_empty,none,
"expected expression in branch of '%0' expression",
(StmtKind))
ERROR(single_value_stmt_branch_must_end_in_result,none,
"non-expression branch of '%0' expression may only end with a 'throw'",
(StmtKind))
"non-expression branch of '%0' expression may only end with a 'throw'"
"%select{| or 'fallthrough'}1",
(StmtKind, bool))
ERROR(cannot_jump_in_single_value_stmt,none,
"cannot '%0' in '%1' when used as expression",
"cannot use '%0' to transfer control out of '%1' expression",
(StmtKind, StmtKind))
WARNING(effect_marker_on_single_value_stmt,none,
"'%0' has no effect on '%1' expression", (StringRef, StmtKind))
Expand Down
35 changes: 15 additions & 20 deletions include/swift/AST/Stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -1159,36 +1159,29 @@ class alignas(1 << PatternAlignInBits) CaseLabelItem {
/// FallthroughStmt - The keyword "fallthrough".
class FallthroughStmt : public Stmt {
SourceLoc Loc;
CaseStmt *FallthroughSource;
CaseStmt *FallthroughDest;
DeclContext *DC;

public:
FallthroughStmt(SourceLoc Loc, std::optional<bool> implicit = std::nullopt)
FallthroughStmt(SourceLoc Loc, DeclContext *DC,
std::optional<bool> implicit = std::nullopt)
: Stmt(StmtKind::Fallthrough, getDefaultImplicitFlag(implicit, Loc)),
Loc(Loc), FallthroughSource(nullptr), FallthroughDest(nullptr) {}
Loc(Loc), DC(DC) {}
public:
static FallthroughStmt *createParsed(SourceLoc Loc, DeclContext *DC);

SourceLoc getLoc() const { return Loc; }

SourceRange getSourceRange() const { return Loc; }

DeclContext *getDeclContext() const { return DC; }
void setDeclContext(DeclContext *newDC) { DC = newDC; }

/// Get the CaseStmt block from which the fallthrough transfers control.
/// Set during Sema. (May stay null if fallthrough is invalid.)
CaseStmt *getFallthroughSource() const { return FallthroughSource; }
void setFallthroughSource(CaseStmt *C) {
assert(!FallthroughSource && "fallthrough source already set?!");
FallthroughSource = C;
}
/// Returns \c nullptr if the fallthrough is invalid.
CaseStmt *getFallthroughSource() const;

/// Get the CaseStmt block to which the fallthrough transfers control.
/// Set during Sema.
CaseStmt *getFallthroughDest() const {
assert(FallthroughDest && "fallthrough dest is not set until Sema");
return FallthroughDest;
}
void setFallthroughDest(CaseStmt *C) {
assert(!FallthroughDest && "fallthrough dest already set?!");
FallthroughDest = C;
}
/// Returns \c nullptr if the fallthrough is invalid.
CaseStmt *getFallthroughDest() const;

static bool classof(const Stmt *S) {
return S->getKind() == StmtKind::Fallthrough;
Expand Down Expand Up @@ -1613,6 +1606,7 @@ class BreakStmt : public Stmt {
}

DeclContext *getDeclContext() const { return DC; }
void setDeclContext(DeclContext *newDC) { DC = newDC; }

static bool classof(const Stmt *S) {
return S->getKind() == StmtKind::Break;
Expand Down Expand Up @@ -1648,6 +1642,7 @@ class ContinueStmt : public Stmt {
}

DeclContext *getDeclContext() const { return DC; }
void setDeclContext(DeclContext *newDC) { DC = newDC; }

static bool classof(const Stmt *S) {
return S->getKind() == StmtKind::Continue;
Expand Down
23 changes: 23 additions & 0 deletions include/swift/AST/TypeCheckRequests.h
Original file line number Diff line number Diff line change
Expand Up @@ -4162,6 +4162,29 @@ class ContinueTargetRequest
bool isCached() const { return true; }
};

struct FallthroughSourceAndDest {
CaseStmt *Source;
CaseStmt *Dest;
};

/// Lookup the source and destination of a 'fallthrough'.
class FallthroughSourceAndDestRequest
: public SimpleRequest<FallthroughSourceAndDestRequest,
FallthroughSourceAndDest(const FallthroughStmt *),
RequestFlags::Cached> {
public:
using SimpleRequest::SimpleRequest;

private:
friend SimpleRequest;

FallthroughSourceAndDest evaluate(Evaluator &evaluator,
const FallthroughStmt *FS) const;

public:
bool isCached() const { return true; }
};

/// Precheck a ReturnStmt, which involves some initial validation, as well as
/// applying a conversion to a FailStmt if needed.
class PreCheckReturnStmtRequest
Expand Down
3 changes: 3 additions & 0 deletions include/swift/AST/TypeCheckerTypeIDZone.def
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,9 @@ SWIFT_REQUEST(TypeChecker, BreakTargetRequest,
SWIFT_REQUEST(TypeChecker, ContinueTargetRequest,
LabeledStmt *(const ContinueStmt *),
Cached, NoLocationInfo)
SWIFT_REQUEST(TypeChecker, FallthroughSourceAndDestRequest,
FallthroughSourceAndDest(const FallthroughStmt *),
Cached, NoLocationInfo)
SWIFT_REQUEST(TypeChecker, PreCheckReturnStmtRequest,
Stmt *(ReturnStmt *, DeclContext *),
Cached, NoLocationInfo)
Expand Down
6 changes: 3 additions & 3 deletions lib/AST/ASTBridging.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2005,9 +2005,9 @@ BridgedDoCatchStmt BridgedDoCatchStmt_createParsed(
}

BridgedFallthroughStmt
BridgedFallthroughStmt_createParsed(BridgedASTContext cContext,
BridgedSourceLoc cLoc) {
return new (cContext.unbridged()) FallthroughStmt(cLoc.unbridged());
BridgedFallthroughStmt_createParsed(BridgedSourceLoc cLoc,
BridgedDeclContext cDC) {
return FallthroughStmt::createParsed(cLoc.unbridged(), cDC.unbridged());
}

BridgedForEachStmt BridgedForEachStmt_createParsed(
Expand Down
17 changes: 17 additions & 0 deletions lib/AST/Stmt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -989,6 +989,23 @@ LabeledStmt *ContinueStmt::getTarget() const {
return evaluateOrDefault(eval, ContinueTargetRequest{this}, nullptr);
}

FallthroughStmt *FallthroughStmt::createParsed(SourceLoc Loc, DeclContext *DC) {
auto &ctx = DC->getASTContext();
return new (ctx) FallthroughStmt(Loc, DC);
}

CaseStmt *FallthroughStmt::getFallthroughSource() const {
auto &eval = getDeclContext()->getASTContext().evaluator;
return evaluateOrDefault(eval, FallthroughSourceAndDestRequest{this}, {})
.Source;
}

CaseStmt *FallthroughStmt::getFallthroughDest() const {
auto &eval = getDeclContext()->getASTContext().evaluator;
return evaluateOrDefault(eval, FallthroughSourceAndDestRequest{this}, {})
.Dest;
}

SourceLoc swift::extractNearestSourceLoc(const Stmt *S) {
return S->getStartLoc();
}
Expand Down
4 changes: 2 additions & 2 deletions lib/ASTGen/Sources/ASTGen/Stmts.swift
Original file line number Diff line number Diff line change
Expand Up @@ -339,8 +339,8 @@ extension ASTGenVisitor {

func generate(fallThroughStmt node: FallThroughStmtSyntax) -> BridgedFallthroughStmt {
return .createParsed(
self.ctx,
loc: self.generateSourceLoc(node.fallthroughKeyword)
loc: self.generateSourceLoc(node.fallthroughKeyword),
declContext: self.declContext
)
}

Expand Down
4 changes: 2 additions & 2 deletions lib/Parse/ParseStmt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -648,8 +648,8 @@ ParserResult<Stmt> Parser::parseStmt(bool fromASTGen) {
if (LabelInfo) diagnose(LabelInfo.Loc, diag::invalid_label_on_stmt);
if (tryLoc.isValid()) diagnose(tryLoc, diag::try_on_stmt, Tok.getText());

return makeParserResult(
new (Context) FallthroughStmt(consumeToken(tok::kw_fallthrough)));
auto loc = consumeToken(tok::kw_fallthrough);
return makeParserResult(FallthroughStmt::createParsed(loc, CurDeclContext));
}
case tok::pound_assert:
if (LabelInfo) diagnose(LabelInfo.Loc, diag::invalid_label_on_stmt);
Expand Down
2 changes: 1 addition & 1 deletion lib/Sema/CSSyntacticElement.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1794,7 +1794,7 @@ class SyntacticElementSolutionApplication
}

ASTNode visitFallthroughStmt(FallthroughStmt *fallthroughStmt) {
if (checkFallthroughStmt(context.getAsDeclContext(), fallthroughStmt))
if (checkFallthroughStmt(fallthroughStmt))
hadError = true;
return fallthroughStmt;
}
Expand Down
2 changes: 1 addition & 1 deletion lib/Sema/MiscDiagnostics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4478,7 +4478,7 @@ class SingleValueStmtUsageChecker final : public ASTWalker {
// default.
Diags.diagnose(branch->getEndLoc(),
diag::single_value_stmt_branch_must_end_in_result,
S->getKind());
S->getKind(), isa<SwitchStmt>(S));
}
break;
}
Expand Down
Loading