@@ -1004,20 +1004,20 @@ class StmtChecker : public StmtVisitor<StmtChecker, Stmt*> {
1004
1004
Stmt *visitBraceStmt (BraceStmt *BS);
1005
1005
1006
1006
Stmt *visitReturnStmt (ReturnStmt *RS) {
1007
- auto TheFunc = AnyFunctionRef::fromDeclContext (DC);
1007
+ // First, let's do a pre-check, and bail if the return is completely
1008
+ // invalid.
1009
+ auto &eval = getASTContext ().evaluator ;
1010
+ auto *S =
1011
+ evaluateOrDefault (eval, PreCheckReturnStmtRequest{RS, DC}, nullptr );
1012
+
1013
+ // We do a cast here as it may have been turned into a FailStmt. We should
1014
+ // return that without doing anything else.
1015
+ RS = dyn_cast_or_null<ReturnStmt>(S);
1016
+ if (!RS)
1017
+ return S;
1008
1018
1009
- if (!TheFunc.has_value ()) {
1010
- getASTContext ().Diags .diagnose (RS->getReturnLoc (),
1011
- diag::return_invalid_outside_func);
1012
- return nullptr ;
1013
- }
1014
-
1015
- // If the return is in a defer, then it isn't valid either.
1016
- if (isInDefer ()) {
1017
- getASTContext ().Diags .diagnose (RS->getReturnLoc (),
1018
- diag::jump_out_of_defer, " return" );
1019
- return nullptr ;
1020
- }
1019
+ auto TheFunc = AnyFunctionRef::fromDeclContext (DC);
1020
+ assert (TheFunc && " Should have bailed from pre-check if this is None" );
1021
1021
1022
1022
Type ResultTy = TheFunc->getBodyResultType ();
1023
1023
if (!ResultTy || ResultTy->hasError ())
@@ -1055,40 +1055,6 @@ class StmtChecker : public StmtVisitor<StmtChecker, Stmt*> {
1055
1055
}
1056
1056
1057
1057
Expr *E = RS->getResult ();
1058
-
1059
- // In an initializer, the only expression allowed is "nil", which indicates
1060
- // failure from a failable initializer.
1061
- if (auto ctor = dyn_cast_or_null<ConstructorDecl>(
1062
- TheFunc->getAbstractFunctionDecl ())) {
1063
- // The only valid return expression in an initializer is the literal
1064
- // 'nil'.
1065
- auto nilExpr = dyn_cast<NilLiteralExpr>(E->getSemanticsProvidingExpr ());
1066
- if (!nilExpr) {
1067
- getASTContext ().Diags .diagnose (RS->getReturnLoc (),
1068
- diag::return_init_non_nil)
1069
- .highlight (E->getSourceRange ());
1070
- RS->setResult (nullptr );
1071
- return RS;
1072
- }
1073
-
1074
- // "return nil" is only permitted in a failable initializer.
1075
- if (!ctor->isFailable ()) {
1076
- getASTContext ().Diags .diagnose (RS->getReturnLoc (),
1077
- diag::return_non_failable_init)
1078
- .highlight (E->getSourceRange ());
1079
- getASTContext ().Diags .diagnose (ctor->getLoc (), diag::make_init_failable,
1080
- ctor->getName ())
1081
- .fixItInsertAfter (ctor->getLoc (), " ?" );
1082
- RS->setResult (nullptr );
1083
- return RS;
1084
- }
1085
-
1086
- // Replace the "return nil" with a new 'fail' statement.
1087
- return new (getASTContext ()) FailStmt (RS->getReturnLoc (),
1088
- nilExpr->getLoc (),
1089
- RS->isImplicit ());
1090
- }
1091
-
1092
1058
TypeCheckExprOptions options = {};
1093
1059
1094
1060
if (LeaveBraceStmtBodyUnchecked) {
@@ -1541,6 +1507,62 @@ class StmtChecker : public StmtVisitor<StmtChecker, Stmt*> {
1541
1507
};
1542
1508
} // end anonymous namespace
1543
1509
1510
+ Stmt *PreCheckReturnStmtRequest::evaluate (Evaluator &evaluator, ReturnStmt *RS,
1511
+ DeclContext *DC) const {
1512
+ auto &ctx = DC->getASTContext ();
1513
+ auto fn = AnyFunctionRef::fromDeclContext (DC);
1514
+
1515
+ // Not valid outside of a function.
1516
+ if (!fn) {
1517
+ ctx.Diags .diagnose (RS->getReturnLoc (), diag::return_invalid_outside_func);
1518
+ return nullptr ;
1519
+ }
1520
+
1521
+ // If the return is in a defer, then it isn't valid either.
1522
+ if (isDefer (DC)) {
1523
+ ctx.Diags .diagnose (RS->getReturnLoc (), diag::jump_out_of_defer, " return" );
1524
+ return nullptr ;
1525
+ }
1526
+
1527
+ // The rest of the checks only concern return statements with results.
1528
+ if (!RS->hasResult ())
1529
+ return RS;
1530
+
1531
+ auto *E = RS->getResult ();
1532
+
1533
+ // In an initializer, the only expression allowed is "nil", which indicates
1534
+ // failure from a failable initializer.
1535
+ if (auto *ctor =
1536
+ dyn_cast_or_null<ConstructorDecl>(fn->getAbstractFunctionDecl ())) {
1537
+
1538
+ // The only valid return expression in an initializer is the literal
1539
+ // 'nil'.
1540
+ auto *nilExpr = dyn_cast<NilLiteralExpr>(E->getSemanticsProvidingExpr ());
1541
+ if (!nilExpr) {
1542
+ ctx.Diags .diagnose (RS->getReturnLoc (), diag::return_init_non_nil)
1543
+ .highlight (E->getSourceRange ());
1544
+ RS->setResult (nullptr );
1545
+ return RS;
1546
+ }
1547
+
1548
+ // "return nil" is only permitted in a failable initializer.
1549
+ if (!ctor->isFailable ()) {
1550
+ ctx.Diags .diagnose (RS->getReturnLoc (), diag::return_non_failable_init)
1551
+ .highlight (E->getSourceRange ());
1552
+ ctx.Diags
1553
+ .diagnose (ctor->getLoc (), diag::make_init_failable, ctor->getName ())
1554
+ .fixItInsertAfter (ctor->getLoc (), " ?" );
1555
+ RS->setResult (nullptr );
1556
+ return RS;
1557
+ }
1558
+
1559
+ // Replace the "return nil" with a new 'fail' statement.
1560
+ return new (ctx)
1561
+ FailStmt (RS->getReturnLoc (), nilExpr->getLoc (), RS->isImplicit ());
1562
+ }
1563
+ return RS;
1564
+ }
1565
+
1544
1566
static bool isDiscardableType (Type type) {
1545
1567
return (type->hasError () ||
1546
1568
type->isUninhabited () ||
0 commit comments