Skip to content

Commit 4e81847

Browse files
committed
[Sema] Introduce PreCheckReturnStmtRequest
Factor out some type-checking logic for ReturnStmt, including the conversion to FailStmt, into a request. We can then invoke this request from both the regular type-checking path, as well as during a pre-check of an expression.
1 parent e75c4bc commit 4e81847

File tree

4 files changed

+102
-47
lines changed

4 files changed

+102
-47
lines changed

include/swift/AST/TypeCheckRequests.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ class PropertyWrapperInitializerInfo;
5555
struct PropertyWrapperLValueness;
5656
struct PropertyWrapperMutability;
5757
class RequirementRepr;
58+
class ReturnStmt;
5859
class SpecializeAttr;
5960
class TrailingWhereClause;
6061
class TypeAliasDecl;
@@ -3749,6 +3750,24 @@ class ContinueTargetRequest
37493750
bool isCached() const { return true; }
37503751
};
37513752

3753+
/// Precheck a ReturnStmt, which involves some initial validation, as well as
3754+
/// applying a conversion to a FailStmt if needed.
3755+
class PreCheckReturnStmtRequest
3756+
: public SimpleRequest<PreCheckReturnStmtRequest,
3757+
Stmt *(ReturnStmt *, DeclContext *DC),
3758+
RequestFlags::Cached> {
3759+
public:
3760+
using SimpleRequest::SimpleRequest;
3761+
3762+
private:
3763+
friend SimpleRequest;
3764+
3765+
Stmt *evaluate(Evaluator &evaluator, ReturnStmt *RS, DeclContext *DC) const;
3766+
3767+
public:
3768+
bool isCached() const { return true; }
3769+
};
3770+
37523771
class GetTypeWrapperInitializer
37533772
: public SimpleRequest<GetTypeWrapperInitializer,
37543773
ConstructorDecl *(NominalTypeDecl *),

include/swift/AST/TypeCheckerTypeIDZone.def

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,9 @@ SWIFT_REQUEST(TypeChecker, BreakTargetRequest,
440440
SWIFT_REQUEST(TypeChecker, ContinueTargetRequest,
441441
LabeledStmt *(const ContinueStmt *),
442442
Cached, NoLocationInfo)
443+
SWIFT_REQUEST(TypeChecker, PreCheckReturnStmtRequest,
444+
Stmt *(ReturnStmt *, DeclContext *),
445+
Cached, NoLocationInfo)
443446
SWIFT_REQUEST(TypeChecker, GetTypeWrapperInitializer,
444447
ConstructorDecl *(NominalTypeDecl *),
445448
Cached, NoLocationInfo)

lib/Sema/PreCheckExpr.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1300,6 +1300,17 @@ namespace {
13001300
}
13011301

13021302
PreWalkResult<Stmt *> walkToStmtPre(Stmt *stmt) override {
1303+
if (auto *RS = dyn_cast<ReturnStmt>(stmt)) {
1304+
// Pre-check a return statement, which includes potentially turning it
1305+
// into a FailStmt.
1306+
auto &eval = Ctx.evaluator;
1307+
auto *S = evaluateOrDefault(eval, PreCheckReturnStmtRequest{RS, DC},
1308+
nullptr);
1309+
if (!S)
1310+
return Action::Stop();
1311+
1312+
return Action::Continue(S);
1313+
}
13031314
return Action::Continue(stmt);
13041315
}
13051316

lib/Sema/TypeCheckStmt.cpp

Lines changed: 69 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1004,20 +1004,20 @@ class StmtChecker : public StmtVisitor<StmtChecker, Stmt*> {
10041004
Stmt *visitBraceStmt(BraceStmt *BS);
10051005

10061006
Stmt *visitReturnStmt(ReturnStmt *RS) {
1007-
auto TheFunc = AnyFunctionRef::fromDeclContext(DC);
1007+
// First, let's do a pre-check, and bail if the return is completely
1008+
// invalid.
1009+
auto &eval = getASTContext().evaluator;
1010+
auto *S =
1011+
evaluateOrDefault(eval, PreCheckReturnStmtRequest{RS, DC}, nullptr);
1012+
1013+
// We do a cast here as it may have been turned into a FailStmt. We should
1014+
// return that without doing anything else.
1015+
RS = dyn_cast_or_null<ReturnStmt>(S);
1016+
if (!RS)
1017+
return S;
10081018

1009-
if (!TheFunc.has_value()) {
1010-
getASTContext().Diags.diagnose(RS->getReturnLoc(),
1011-
diag::return_invalid_outside_func);
1012-
return nullptr;
1013-
}
1014-
1015-
// If the return is in a defer, then it isn't valid either.
1016-
if (isInDefer()) {
1017-
getASTContext().Diags.diagnose(RS->getReturnLoc(),
1018-
diag::jump_out_of_defer, "return");
1019-
return nullptr;
1020-
}
1019+
auto TheFunc = AnyFunctionRef::fromDeclContext(DC);
1020+
assert(TheFunc && "Should have bailed from pre-check if this is None");
10211021

10221022
Type ResultTy = TheFunc->getBodyResultType();
10231023
if (!ResultTy || ResultTy->hasError())
@@ -1055,40 +1055,6 @@ class StmtChecker : public StmtVisitor<StmtChecker, Stmt*> {
10551055
}
10561056

10571057
Expr *E = RS->getResult();
1058-
1059-
// In an initializer, the only expression allowed is "nil", which indicates
1060-
// failure from a failable initializer.
1061-
if (auto ctor = dyn_cast_or_null<ConstructorDecl>(
1062-
TheFunc->getAbstractFunctionDecl())) {
1063-
// The only valid return expression in an initializer is the literal
1064-
// 'nil'.
1065-
auto nilExpr = dyn_cast<NilLiteralExpr>(E->getSemanticsProvidingExpr());
1066-
if (!nilExpr) {
1067-
getASTContext().Diags.diagnose(RS->getReturnLoc(),
1068-
diag::return_init_non_nil)
1069-
.highlight(E->getSourceRange());
1070-
RS->setResult(nullptr);
1071-
return RS;
1072-
}
1073-
1074-
// "return nil" is only permitted in a failable initializer.
1075-
if (!ctor->isFailable()) {
1076-
getASTContext().Diags.diagnose(RS->getReturnLoc(),
1077-
diag::return_non_failable_init)
1078-
.highlight(E->getSourceRange());
1079-
getASTContext().Diags.diagnose(ctor->getLoc(), diag::make_init_failable,
1080-
ctor->getName())
1081-
.fixItInsertAfter(ctor->getLoc(), "?");
1082-
RS->setResult(nullptr);
1083-
return RS;
1084-
}
1085-
1086-
// Replace the "return nil" with a new 'fail' statement.
1087-
return new (getASTContext()) FailStmt(RS->getReturnLoc(),
1088-
nilExpr->getLoc(),
1089-
RS->isImplicit());
1090-
}
1091-
10921058
TypeCheckExprOptions options = {};
10931059

10941060
if (LeaveBraceStmtBodyUnchecked) {
@@ -1541,6 +1507,62 @@ class StmtChecker : public StmtVisitor<StmtChecker, Stmt*> {
15411507
};
15421508
} // end anonymous namespace
15431509

1510+
Stmt *PreCheckReturnStmtRequest::evaluate(Evaluator &evaluator, ReturnStmt *RS,
1511+
DeclContext *DC) const {
1512+
auto &ctx = DC->getASTContext();
1513+
auto fn = AnyFunctionRef::fromDeclContext(DC);
1514+
1515+
// Not valid outside of a function.
1516+
if (!fn) {
1517+
ctx.Diags.diagnose(RS->getReturnLoc(), diag::return_invalid_outside_func);
1518+
return nullptr;
1519+
}
1520+
1521+
// If the return is in a defer, then it isn't valid either.
1522+
if (isDefer(DC)) {
1523+
ctx.Diags.diagnose(RS->getReturnLoc(), diag::jump_out_of_defer, "return");
1524+
return nullptr;
1525+
}
1526+
1527+
// The rest of the checks only concern return statements with results.
1528+
if (!RS->hasResult())
1529+
return RS;
1530+
1531+
auto *E = RS->getResult();
1532+
1533+
// In an initializer, the only expression allowed is "nil", which indicates
1534+
// failure from a failable initializer.
1535+
if (auto *ctor =
1536+
dyn_cast_or_null<ConstructorDecl>(fn->getAbstractFunctionDecl())) {
1537+
1538+
// The only valid return expression in an initializer is the literal
1539+
// 'nil'.
1540+
auto *nilExpr = dyn_cast<NilLiteralExpr>(E->getSemanticsProvidingExpr());
1541+
if (!nilExpr) {
1542+
ctx.Diags.diagnose(RS->getReturnLoc(), diag::return_init_non_nil)
1543+
.highlight(E->getSourceRange());
1544+
RS->setResult(nullptr);
1545+
return RS;
1546+
}
1547+
1548+
// "return nil" is only permitted in a failable initializer.
1549+
if (!ctor->isFailable()) {
1550+
ctx.Diags.diagnose(RS->getReturnLoc(), diag::return_non_failable_init)
1551+
.highlight(E->getSourceRange());
1552+
ctx.Diags
1553+
.diagnose(ctor->getLoc(), diag::make_init_failable, ctor->getName())
1554+
.fixItInsertAfter(ctor->getLoc(), "?");
1555+
RS->setResult(nullptr);
1556+
return RS;
1557+
}
1558+
1559+
// Replace the "return nil" with a new 'fail' statement.
1560+
return new (ctx)
1561+
FailStmt(RS->getReturnLoc(), nilExpr->getLoc(), RS->isImplicit());
1562+
}
1563+
return RS;
1564+
}
1565+
15441566
static bool isDiscardableType(Type type) {
15451567
return (type->hasError() ||
15461568
type->isUninhabited() ||

0 commit comments

Comments
 (0)