Skip to content

Commit df2b3b2

Browse files
committed
[Sema] Introduce PreCheckReturnStmtRequest
Factor out some type-checking logic for ReturnStmt, including the conversion to FailStmt, into a request. We can then invoke this request from both the regular type-checking path, as well as during a pre-check of an expression.
1 parent b3c12ca commit df2b3b2

File tree

4 files changed

+102
-47
lines changed

4 files changed

+102
-47
lines changed

include/swift/AST/TypeCheckRequests.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ class PropertyWrapperInitializerInfo;
5555
struct PropertyWrapperLValueness;
5656
struct PropertyWrapperMutability;
5757
class RequirementRepr;
58+
class ReturnStmt;
5859
class SpecializeAttr;
5960
class TrailingWhereClause;
6061
class TypeAliasDecl;
@@ -3811,6 +3812,24 @@ class ContinueTargetRequest
38113812
bool isCached() const { return true; }
38123813
};
38133814

3815+
/// Precheck a ReturnStmt, which involves some initial validation, as well as
3816+
/// applying a conversion to a FailStmt if needed.
3817+
class PreCheckReturnStmtRequest
3818+
: public SimpleRequest<PreCheckReturnStmtRequest,
3819+
Stmt *(ReturnStmt *, DeclContext *DC),
3820+
RequestFlags::Cached> {
3821+
public:
3822+
using SimpleRequest::SimpleRequest;
3823+
3824+
private:
3825+
friend SimpleRequest;
3826+
3827+
Stmt *evaluate(Evaluator &evaluator, ReturnStmt *RS, DeclContext *DC) const;
3828+
3829+
public:
3830+
bool isCached() const { return true; }
3831+
};
3832+
38143833
class GetTypeWrapperInitializer
38153834
: public SimpleRequest<GetTypeWrapperInitializer,
38163835
ConstructorDecl *(NominalTypeDecl *),

include/swift/AST/TypeCheckerTypeIDZone.def

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,9 @@ SWIFT_REQUEST(TypeChecker, BreakTargetRequest,
443443
SWIFT_REQUEST(TypeChecker, ContinueTargetRequest,
444444
LabeledStmt *(const ContinueStmt *),
445445
Cached, NoLocationInfo)
446+
SWIFT_REQUEST(TypeChecker, PreCheckReturnStmtRequest,
447+
Stmt *(ReturnStmt *, DeclContext *),
448+
Cached, NoLocationInfo)
446449
SWIFT_REQUEST(TypeChecker, GetTypeWrapperInitializer,
447450
ConstructorDecl *(NominalTypeDecl *),
448451
Cached, NoLocationInfo)

lib/Sema/PreCheckExpr.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1300,6 +1300,17 @@ namespace {
13001300
}
13011301

13021302
PreWalkResult<Stmt *> walkToStmtPre(Stmt *stmt) override {
1303+
if (auto *RS = dyn_cast<ReturnStmt>(stmt)) {
1304+
// Pre-check a return statement, which includes potentially turning it
1305+
// into a FailStmt.
1306+
auto &eval = Ctx.evaluator;
1307+
auto *S = evaluateOrDefault(eval, PreCheckReturnStmtRequest{RS, DC},
1308+
nullptr);
1309+
if (!S)
1310+
return Action::Stop();
1311+
1312+
return Action::Continue(S);
1313+
}
13031314
return Action::Continue(stmt);
13041315
}
13051316

lib/Sema/TypeCheckStmt.cpp

Lines changed: 69 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1010,20 +1010,20 @@ class StmtChecker : public StmtVisitor<StmtChecker, Stmt*> {
10101010
Stmt *visitBraceStmt(BraceStmt *BS);
10111011

10121012
Stmt *visitReturnStmt(ReturnStmt *RS) {
1013-
auto TheFunc = AnyFunctionRef::fromDeclContext(DC);
1013+
// First, let's do a pre-check, and bail if the return is completely
1014+
// invalid.
1015+
auto &eval = getASTContext().evaluator;
1016+
auto *S =
1017+
evaluateOrDefault(eval, PreCheckReturnStmtRequest{RS, DC}, nullptr);
1018+
1019+
// We do a cast here as it may have been turned into a FailStmt. We should
1020+
// return that without doing anything else.
1021+
RS = dyn_cast_or_null<ReturnStmt>(S);
1022+
if (!RS)
1023+
return S;
10141024

1015-
if (!TheFunc.has_value()) {
1016-
getASTContext().Diags.diagnose(RS->getReturnLoc(),
1017-
diag::return_invalid_outside_func);
1018-
return nullptr;
1019-
}
1020-
1021-
// If the return is in a defer, then it isn't valid either.
1022-
if (isInDefer()) {
1023-
getASTContext().Diags.diagnose(RS->getReturnLoc(),
1024-
diag::jump_out_of_defer, "return");
1025-
return nullptr;
1026-
}
1025+
auto TheFunc = AnyFunctionRef::fromDeclContext(DC);
1026+
assert(TheFunc && "Should have bailed from pre-check if this is None");
10271027

10281028
Type ResultTy = TheFunc->getBodyResultType();
10291029
if (!ResultTy || ResultTy->hasError())
@@ -1061,40 +1061,6 @@ class StmtChecker : public StmtVisitor<StmtChecker, Stmt*> {
10611061
}
10621062

10631063
Expr *E = RS->getResult();
1064-
1065-
// In an initializer, the only expression allowed is "nil", which indicates
1066-
// failure from a failable initializer.
1067-
if (auto ctor = dyn_cast_or_null<ConstructorDecl>(
1068-
TheFunc->getAbstractFunctionDecl())) {
1069-
// The only valid return expression in an initializer is the literal
1070-
// 'nil'.
1071-
auto nilExpr = dyn_cast<NilLiteralExpr>(E->getSemanticsProvidingExpr());
1072-
if (!nilExpr) {
1073-
getASTContext().Diags.diagnose(RS->getReturnLoc(),
1074-
diag::return_init_non_nil)
1075-
.highlight(E->getSourceRange());
1076-
RS->setResult(nullptr);
1077-
return RS;
1078-
}
1079-
1080-
// "return nil" is only permitted in a failable initializer.
1081-
if (!ctor->isFailable()) {
1082-
getASTContext().Diags.diagnose(RS->getReturnLoc(),
1083-
diag::return_non_failable_init)
1084-
.highlight(E->getSourceRange());
1085-
getASTContext().Diags.diagnose(ctor->getLoc(), diag::make_init_failable,
1086-
ctor->getName())
1087-
.fixItInsertAfter(ctor->getLoc(), "?");
1088-
RS->setResult(nullptr);
1089-
return RS;
1090-
}
1091-
1092-
// Replace the "return nil" with a new 'fail' statement.
1093-
return new (getASTContext()) FailStmt(RS->getReturnLoc(),
1094-
nilExpr->getLoc(),
1095-
RS->isImplicit());
1096-
}
1097-
10981064
TypeCheckExprOptions options = {};
10991065

11001066
if (LeaveBraceStmtBodyUnchecked) {
@@ -1547,6 +1513,62 @@ class StmtChecker : public StmtVisitor<StmtChecker, Stmt*> {
15471513
};
15481514
} // end anonymous namespace
15491515

1516+
Stmt *PreCheckReturnStmtRequest::evaluate(Evaluator &evaluator, ReturnStmt *RS,
1517+
DeclContext *DC) const {
1518+
auto &ctx = DC->getASTContext();
1519+
auto fn = AnyFunctionRef::fromDeclContext(DC);
1520+
1521+
// Not valid outside of a function.
1522+
if (!fn) {
1523+
ctx.Diags.diagnose(RS->getReturnLoc(), diag::return_invalid_outside_func);
1524+
return nullptr;
1525+
}
1526+
1527+
// If the return is in a defer, then it isn't valid either.
1528+
if (isDefer(DC)) {
1529+
ctx.Diags.diagnose(RS->getReturnLoc(), diag::jump_out_of_defer, "return");
1530+
return nullptr;
1531+
}
1532+
1533+
// The rest of the checks only concern return statements with results.
1534+
if (!RS->hasResult())
1535+
return RS;
1536+
1537+
auto *E = RS->getResult();
1538+
1539+
// In an initializer, the only expression allowed is "nil", which indicates
1540+
// failure from a failable initializer.
1541+
if (auto *ctor =
1542+
dyn_cast_or_null<ConstructorDecl>(fn->getAbstractFunctionDecl())) {
1543+
1544+
// The only valid return expression in an initializer is the literal
1545+
// 'nil'.
1546+
auto *nilExpr = dyn_cast<NilLiteralExpr>(E->getSemanticsProvidingExpr());
1547+
if (!nilExpr) {
1548+
ctx.Diags.diagnose(RS->getReturnLoc(), diag::return_init_non_nil)
1549+
.highlight(E->getSourceRange());
1550+
RS->setResult(nullptr);
1551+
return RS;
1552+
}
1553+
1554+
// "return nil" is only permitted in a failable initializer.
1555+
if (!ctor->isFailable()) {
1556+
ctx.Diags.diagnose(RS->getReturnLoc(), diag::return_non_failable_init)
1557+
.highlight(E->getSourceRange());
1558+
ctx.Diags
1559+
.diagnose(ctor->getLoc(), diag::make_init_failable, ctor->getName())
1560+
.fixItInsertAfter(ctor->getLoc(), "?");
1561+
RS->setResult(nullptr);
1562+
return RS;
1563+
}
1564+
1565+
// Replace the "return nil" with a new 'fail' statement.
1566+
return new (ctx)
1567+
FailStmt(RS->getReturnLoc(), nilExpr->getLoc(), RS->isImplicit());
1568+
}
1569+
return RS;
1570+
}
1571+
15501572
static bool isDiscardableType(Type type) {
15511573
return (type->hasError() ||
15521574
type->isUninhabited() ||

0 commit comments

Comments
 (0)