Skip to content

Commit 7e9abe7

Browse files
committed
[Typed throws] Implement support for do throws(...) syntax
During the review of SE-0413, typed throws, the notion of a `do throws` syntax for `do..catch` blocks came up. Implement that syntax and semantics, as a way to explicitly specify the type of error that is thrown from the `do` body in `do..catch` statement.
1 parent 5aa29c0 commit 7e9abe7

15 files changed

+194
-24
lines changed

include/swift/AST/CatchNode.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ class CatchNode: public llvm::PointerUnion<
3535
///
3636
/// Returns the thrown error type for a throwing context, or \c llvm::None
3737
/// if this is a non-throwing context.
38-
llvm::Optional<Type> getThrownErrorTypeInContext(ASTContext &ctx) const;
38+
llvm::Optional<Type> getThrownErrorTypeInContext(DeclContext *dc) const;
3939
};
4040

4141
} // end namespace swift

include/swift/AST/DiagnosticsParse.def

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1196,6 +1196,9 @@ ERROR(expected_catch_where_expr,PointsToFirstBadToken,
11961196
ERROR(docatch_not_trycatch,PointsToFirstBadToken,
11971197
"the 'do' keyword is used to specify a 'catch' region",
11981198
())
1199+
ERROR(do_throws_without_catch,none,
1200+
"a 'do' statement with a 'throws' clause must have at least one 'catch'",
1201+
())
11991202

12001203
// C-Style For Stmt
12011204
ERROR(c_style_for_stmt_removed,none,

include/swift/AST/Stmt.h

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "swift/AST/ConcreteDeclRef.h"
2525
#include "swift/AST/IfConfigClause.h"
2626
#include "swift/AST/TypeAlignments.h"
27+
#include "swift/AST/TypeLoc.h"
2728
#include "swift/AST/ThrownErrorDestination.h"
2829
#include "swift/Basic/Debug.h"
2930
#include "swift/Basic/NullablePtr.h"
@@ -1381,16 +1382,25 @@ class DoCatchStmt final
13811382
: public LabeledStmt,
13821383
private llvm::TrailingObjects<DoCatchStmt, CaseStmt *> {
13831384
friend TrailingObjects;
1385+
friend class DoCatchExplicitThrownTypeRequest;
13841386

13851387
SourceLoc DoLoc;
1388+
1389+
/// Location of the 'throws' token.
1390+
SourceLoc ThrowsLoc;
1391+
1392+
/// The error type that is being thrown.
1393+
TypeLoc ThrownType;
1394+
13861395
Stmt *Body;
13871396
ThrownErrorDestination RethrowDest;
13881397

1389-
DoCatchStmt(LabeledStmtInfo labelInfo, SourceLoc doLoc, Stmt *body,
1398+
DoCatchStmt(LabeledStmtInfo labelInfo, SourceLoc doLoc,
1399+
SourceLoc throwsLoc, TypeLoc thrownType, Stmt *body,
13901400
ArrayRef<CaseStmt *> catches, llvm::Optional<bool> implicit)
13911401
: LabeledStmt(StmtKind::DoCatch, getDefaultImplicitFlag(implicit, doLoc),
13921402
labelInfo),
1393-
DoLoc(doLoc), Body(body) {
1403+
DoLoc(doLoc), ThrowsLoc(throwsLoc), ThrownType(thrownType), Body(body) {
13941404
Bits.DoCatchStmt.NumCatches = catches.size();
13951405
std::uninitialized_copy(catches.begin(), catches.end(),
13961406
getTrailingObjects<CaseStmt *>());
@@ -1400,15 +1410,28 @@ class DoCatchStmt final
14001410

14011411
public:
14021412
static DoCatchStmt *create(ASTContext &ctx, LabeledStmtInfo labelInfo,
1403-
SourceLoc doLoc, Stmt *body,
1413+
SourceLoc doLoc,
1414+
SourceLoc throwsLoc, TypeLoc thrownType,
1415+
Stmt *body,
14041416
ArrayRef<CaseStmt *> catches,
14051417
llvm::Optional<bool> implicit = llvm::None);
14061418

14071419
SourceLoc getDoLoc() const { return DoLoc; }
14081420

1421+
/// Retrieve the location of the 'throws' keyword, if present.
1422+
SourceLoc getThrowsLoc() const { return ThrowsLoc; }
1423+
14091424
SourceLoc getStartLoc() const { return getLabelLocOrKeywordLoc(DoLoc); }
14101425
SourceLoc getEndLoc() const { return getCatches().back()->getEndLoc(); }
14111426

1427+
/// Retrieves the type representation for the thrown type.
1428+
TypeRepr *getThrownTypeRepr() const {
1429+
return ThrownType.getTypeRepr();
1430+
}
1431+
1432+
// Get the explicitly-specified thrown error type.
1433+
Type getExplicitlyThrownType(DeclContext *dc) const;
1434+
14121435
Stmt *getBody() const { return Body; }
14131436
void setBody(Stmt *s) { Body = s; }
14141437

@@ -1433,7 +1456,7 @@ class DoCatchStmt final
14331456
// and caught by the various 'catch' clauses. If this the catch clauses
14341457
// aren't exhausive, this is also the type of the error that is implicitly
14351458
// rethrown.
1436-
Type getCaughtErrorType() const;
1459+
Type getCaughtErrorType(DeclContext *dc) const;
14371460

14381461
/// Retrieves the rethrown error and its conversion to the error type
14391462
/// expected by the enclosing context.

include/swift/AST/TypeCheckRequests.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ class ContextualPattern;
4646
class ContinueStmt;
4747
class DefaultArgumentExpr;
4848
class DefaultArgumentType;
49+
class DoCatchStmt;
4950
struct ExternalMacroDefinition;
5051
class ClosureExpr;
5152
class GenericParamList;
@@ -2303,6 +2304,27 @@ class ThrownTypeRequest
23032304
void cacheResult(Type value) const;
23042305
};
23052306

2307+
/// Determines the explicitly-written thrown error type in a do..catch block.
2308+
class DoCatchExplicitThrownTypeRequest
2309+
: public SimpleRequest<DoCatchExplicitThrownTypeRequest,
2310+
Type(DeclContext *, DoCatchStmt *),
2311+
RequestFlags::SeparatelyCached> {
2312+
public:
2313+
using SimpleRequest::SimpleRequest;
2314+
2315+
private:
2316+
friend SimpleRequest;
2317+
2318+
// Evaluation.
2319+
Type evaluate(Evaluator &evaluator, DeclContext *dc, DoCatchStmt *stmt) const;
2320+
2321+
public:
2322+
// Separate caching.
2323+
bool isCached() const;
2324+
llvm::Optional<Type> getCachedResult() const;
2325+
void cacheResult(Type value) const;
2326+
};
2327+
23062328
/// Determines the result type of a function or element type of a subscript.
23072329
class ResultTypeRequest
23082330
: public SimpleRequest<ResultTypeRequest,

include/swift/AST/TypeCheckerTypeIDZone.def

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,9 @@ SWIFT_REQUEST(TypeChecker, ParamSpecifierRequest,
360360
ParamDecl::Specifier(ParamDecl *), SeparatelyCached, NoLocationInfo)
361361
SWIFT_REQUEST(TypeChecker, ThrownTypeRequest,
362362
Type(AbstractFunctionDecl *), SeparatelyCached, NoLocationInfo)
363+
SWIFT_REQUEST(TypeChecker, DoCatchExplicitThrownTypeRequest,
364+
Type(DeclContext *, DoCatchStmt *), SeparatelyCached,
365+
NoLocationInfo)
363366
SWIFT_REQUEST(TypeChecker, ResultTypeRequest,
364367
Type(ValueDecl *), SeparatelyCached, NoLocationInfo)
365368
SWIFT_REQUEST(TypeChecker, AreAllStoredPropertiesDefaultInitableRequest,

lib/AST/Decl.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11634,7 +11634,7 @@ MacroDiscriminatorContext::getParentOf(FreestandingMacroExpansion *expansion) {
1163411634
}
1163511635

1163611636
llvm::Optional<Type>
11637-
CatchNode::getThrownErrorTypeInContext(ASTContext &ctx) const {
11637+
CatchNode::getThrownErrorTypeInContext(DeclContext *dc) const {
1163811638
if (auto func = dyn_cast<AbstractFunctionDecl *>()) {
1163911639
if (auto thrownError = func->getEffectiveThrownErrorType())
1164011640
return func->mapTypeIntoContext(*thrownError);
@@ -11651,13 +11651,13 @@ CatchNode::getThrownErrorTypeInContext(ASTContext &ctx) const {
1165111651
}
1165211652

1165311653
auto doCatch = get<DoCatchStmt *>();
11654-
if (auto thrownError = doCatch->getCaughtErrorType()) {
11654+
if (auto thrownError = doCatch->getCaughtErrorType(dc)) {
1165511655
if (thrownError->isNever())
1165611656
return llvm::None;
1165711657

1165811658
return thrownError;
1165911659
}
1166011660

1166111661
// If we haven't computed the error type yet, do so now.
11662-
return ctx.getErrorExistentialType();
11662+
return dc->getASTContext().getErrorExistentialType();
1166311663
}

lib/AST/Stmt.cpp

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -450,12 +450,15 @@ Expr *ForEachStmt::getTypeCheckedSequence() const {
450450
}
451451

452452
DoCatchStmt *DoCatchStmt::create(ASTContext &ctx, LabeledStmtInfo labelInfo,
453-
SourceLoc doLoc, Stmt *body,
453+
SourceLoc doLoc,
454+
SourceLoc throwsLoc, TypeLoc thrownType,
455+
Stmt *body,
454456
ArrayRef<CaseStmt *> catches,
455457
llvm::Optional<bool> implicit) {
456458
void *mem = ctx.Allocate(totalSizeToAlloc<CaseStmt *>(catches.size()),
457459
alignof(DoCatchStmt));
458-
return ::new (mem) DoCatchStmt(labelInfo, doLoc, body, catches, implicit);
460+
return ::new (mem) DoCatchStmt(labelInfo, doLoc, throwsLoc, thrownType, body,
461+
catches, implicit);
459462
}
460463

461464
bool CaseLabelItem::isSyntacticallyExhaustive() const {
@@ -472,7 +475,17 @@ bool DoCatchStmt::isSyntacticallyExhaustive() const {
472475
return false;
473476
}
474477

475-
Type DoCatchStmt::getCaughtErrorType() const {
478+
Type DoCatchStmt::getExplicitlyThrownType(DeclContext *dc) const {
479+
ASTContext &ctx = dc->getASTContext();
480+
DoCatchExplicitThrownTypeRequest request{dc, const_cast<DoCatchStmt *>(this)};
481+
return evaluateOrDefault(ctx.evaluator, request, Type());
482+
}
483+
484+
Type DoCatchStmt::getCaughtErrorType(DeclContext *dc) const {
485+
// Check for an explicitly-specified error type.
486+
if (Type explicitError = getExplicitlyThrownType(dc))
487+
return explicitError;
488+
476489
auto firstPattern = getCatches()
477490
.front()
478491
->getCaseLabelItems()

lib/AST/TypeCheckRequests.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -969,6 +969,29 @@ void ThrownTypeRequest::cacheResult(Type type) const {
969969
func->ThrownType.setType(type);
970970
}
971971

972+
//----------------------------------------------------------------------------//
973+
// DoCatchExplicitThrownTypeRequest computation.
974+
//----------------------------------------------------------------------------//
975+
976+
bool DoCatchExplicitThrownTypeRequest::isCached() const {
977+
auto *const stmt = std::get<1>(getStorage());
978+
return stmt->getThrowsLoc().isValid();
979+
}
980+
981+
llvm::Optional<Type> DoCatchExplicitThrownTypeRequest::getCachedResult() const {
982+
auto *const stmt = std::get<1>(getStorage());
983+
Type thrownType = stmt->ThrownType.getType();
984+
if (thrownType.isNull())
985+
return llvm::None;
986+
987+
return thrownType;
988+
}
989+
990+
void DoCatchExplicitThrownTypeRequest::cacheResult(Type type) const {
991+
auto *const stmt = std::get<1>(getStorage());
992+
stmt->ThrownType.setType(type);
993+
}
994+
972995
//----------------------------------------------------------------------------//
973996
// ResultTypeRequest computation.
974997
//----------------------------------------------------------------------------//

lib/Parse/ParseStmt.cpp

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2191,8 +2191,8 @@ ParserResult<Stmt> Parser::parseStmtRepeat(LabeledStmtInfo labelInfo) {
21912191

21922192
///
21932193
/// stmt-do:
2194-
/// (identifier ':')? 'do' stmt-brace
2195-
/// (identifier ':')? 'do' stmt-brace stmt-catch+
2194+
/// (identifier ':')? 'do' throws-clause? stmt-brace
2195+
/// (identifier ':')? 'do' throws-clause? stmt-brace stmt-catch+
21962196
ParserResult<Stmt> Parser::parseStmtDo(LabeledStmtInfo labelInfo,
21972197
bool shouldSkipDoTokenConsume) {
21982198
SourceLoc doLoc;
@@ -2205,6 +2205,25 @@ ParserResult<Stmt> Parser::parseStmtDo(LabeledStmtInfo labelInfo,
22052205

22062206
ParserStatus status;
22072207

2208+
// Parse the optional 'throws' clause.
2209+
SourceLoc throwsLoc;
2210+
TypeRepr *thrownType = nullptr;
2211+
if (consumeIf(tok::kw_throws, throwsLoc)) {
2212+
// Parse the thrown error type.
2213+
SourceLoc lParenLoc;
2214+
if (consumeIf(tok::l_paren, lParenLoc)) {
2215+
ParserResult<TypeRepr> parsedThrownTy =
2216+
parseType(diag::expected_thrown_error_type);
2217+
thrownType = parsedThrownTy.getPtrOrNull();
2218+
status |= parsedThrownTy;
2219+
2220+
SourceLoc rParenLoc;
2221+
parseMatchingToken(
2222+
tok::r_paren, rParenLoc,
2223+
diag::expected_rparen_after_thrown_error_type, lParenLoc);
2224+
}
2225+
}
2226+
22082227
ParserResult<BraceStmt> body =
22092228
parseBraceItemList(diag::expected_lbrace_after_do);
22102229
status |= body;
@@ -2236,7 +2255,12 @@ ParserResult<Stmt> Parser::parseStmtDo(LabeledStmtInfo labelInfo,
22362255
}
22372256

22382257
return makeParserResult(status,
2239-
DoCatchStmt::create(Context, labelInfo, doLoc, body.get(), allClauses));
2258+
DoCatchStmt::create(Context, labelInfo, doLoc, throwsLoc, thrownType,
2259+
body.get(), allClauses));
2260+
}
2261+
2262+
if (throwsLoc.isValid()) {
2263+
diagnose(throwsLoc, diag::do_throws_without_catch);
22402264
}
22412265

22422266
// If we dont see a 'while' or see a 'while' that starts

lib/SILGen/SILGenStmt.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1117,7 +1117,7 @@ void StmtEmitter::visitDoStmt(DoStmt *S) {
11171117
}
11181118

11191119
void StmtEmitter::visitDoCatchStmt(DoCatchStmt *S) {
1120-
Type formalExnType = S->getCaughtErrorType();
1120+
Type formalExnType = S->getCaughtErrorType(SGF.FunctionDC);
11211121
auto &exnTL = SGF.getTypeLowering(formalExnType);
11221122

11231123
SILValue exnArg;

lib/Sema/TypeCheckEffects.cpp

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2751,9 +2751,10 @@ class CheckEffectsCoverage : public EffectsHandlingWalker<CheckEffectsCoverage>
27512751
/// Retrieve the type of the error that can be caught when an error is
27522752
/// thrown from the given location.
27532753
Type getCaughtErrorTypeAt(SourceLoc loc) {
2754-
auto module = CurContext.getDeclContext()->getParentModule();
2754+
auto dc = CurContext.getDeclContext();
2755+
auto module = dc->getParentModule();
27552756
if (CatchNode catchNode = ASTScope::lookupCatchNode(module, loc)) {
2756-
if (auto caughtType = catchNode.getThrownErrorTypeInContext(Ctx))
2757+
if (auto caughtType = catchNode.getThrownErrorTypeInContext(dc))
27572758
return *caughtType;
27582759
}
27592760

@@ -2880,7 +2881,8 @@ class CheckEffectsCoverage : public EffectsHandlingWalker<CheckEffectsCoverage>
28802881
// specialized diagnostic about non-exhaustive catches.
28812882
if (!CurContext.handlesThrows(ConditionalEffectKind::Conditional)) {
28822883
CurContext.setNonExhaustiveCatch(true);
2883-
} else if (Type rethrownErrorType = S->getCaughtErrorType()) {
2884+
} else if (Type rethrownErrorType =
2885+
S->getCaughtErrorType(CurContext.getDeclContext())) {
28842886
// We're implicitly rethrowing the error out of this do..catch, so make
28852887
// sure that we can throw an error of this type out of this context.
28862888
auto catches = S->getCatches();
@@ -3510,15 +3512,23 @@ llvm::Optional<Type> TypeChecker::canThrow(ASTContext &ctx, Expr *expr) {
35103512
return classification.getThrownError();
35113513
}
35123514

3513-
Type TypeChecker::catchErrorType(ASTContext &ctx, DoCatchStmt *stmt) {
3515+
Type TypeChecker::catchErrorType(DeclContext *dc, DoCatchStmt *stmt) {
3516+
ASTContext &ctx = dc->getASTContext();
3517+
35143518
// When typed throws is disabled, this is always "any Error".
35153519
// FIXME: When we distinguish "precise" typed throws from normal typed
35163520
// throws, we'll be able to compute a more narrow catch error type in some
35173521
// case, e.g., from a `try` but not a `throws`.
35183522
if (!ctx.LangOpts.hasFeature(Feature::TypedThrows))
35193523
return ctx.getErrorExistentialType();
35203524

3521-
// Classify the throwing behavior of the "do" body.
3525+
// If the do..catch statement explicitly specifies that it throws, use
3526+
// that type.
3527+
if (Type explicitError = stmt->getExplicitlyThrownType(dc)) {
3528+
return explicitError;
3529+
}
3530+
3531+
// Otherwise, infer the thrown error type from the "do" body.
35223532
ApplyClassifier classifier(ctx);
35233533
Classification classification = classifier.classifyStmt(
35243534
stmt->getBody(), EffectKind::Throws);

lib/Sema/TypeCheckStmt.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1195,8 +1195,7 @@ class StmtChecker : public StmtVisitor<StmtChecker, Stmt*> {
11951195
DC->getParentModule(), TS->getThrowLoc());
11961196
Type errorType;
11971197
if (catchNode) {
1198-
errorType = catchNode.getThrownErrorTypeInContext(getASTContext())
1199-
.value_or(Type());
1198+
errorType = catchNode.getThrownErrorTypeInContext(DC).value_or(Type());
12001199
}
12011200

12021201
// If there was no error type, use 'any Error'. We'll check it later.
@@ -1679,7 +1678,7 @@ class StmtChecker : public StmtVisitor<StmtChecker, Stmt*> {
16791678
// Do-catch statements always limit exhaustivity checks.
16801679
bool limitExhaustivityChecks = true;
16811680

1682-
Type caughtErrorType = TypeChecker::catchErrorType(Ctx, S);
1681+
Type caughtErrorType = TypeChecker::catchErrorType(DC, S);
16831682
auto catches = S->getCatches();
16841683
checkSiblingCaseStmts(catches.begin(), catches.end(),
16851684
CaseParentKind::DoCatch, limitExhaustivityChecks,

lib/Sema/TypeCheckType.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5626,3 +5626,29 @@ Type CustomAttrTypeRequest::evaluate(Evaluator &eval, CustomAttr *attr,
56265626

56275627
return type;
56285628
}
5629+
5630+
Type DoCatchExplicitThrownTypeRequest::evaluate(
5631+
Evaluator &evaluator, DeclContext *dc, DoCatchStmt *stmt
5632+
) const {
5633+
if (stmt->getThrowsLoc().isInvalid())
5634+
return Type();
5635+
5636+
// If typed throws is not enabled, complain.
5637+
ASTContext &ctx = dc->getASTContext();
5638+
if (!ctx.LangOpts.hasFeature(Feature::TypedThrows)) {
5639+
ctx.Diags.diagnose(stmt->getThrowsLoc(), diag::experimental_typed_throws);
5640+
return Type();
5641+
}
5642+
5643+
auto typeRepr = stmt->getThrownTypeRepr();
5644+
5645+
// If there is no explicitly-specified thrown error type, it's 'any Error'.
5646+
if (!typeRepr) {
5647+
return ctx.getErrorExistentialType();
5648+
}
5649+
5650+
return TypeResolution::resolveContextualType(
5651+
typeRepr, dc, TypeResolutionOptions(TypeResolverContext::None),
5652+
/*unboundTyOpener*/ nullptr, PlaceholderType::get,
5653+
/*packElementOpener*/ nullptr);
5654+
}

lib/Sema/TypeChecker.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1176,7 +1176,7 @@ llvm::Optional<Type> canThrow(ASTContext &ctx, Expr *expr);
11761176
///
11771177
/// The error type is used in the catch clauses and, for a nonexhausive
11781178
/// do-catch, is implicitly rethrown out of the do...catch block.
1179-
Type catchErrorType(ASTContext &ctx, DoCatchStmt *stmt);
1179+
Type catchErrorType(DeclContext *dc, DoCatchStmt *stmt);
11801180

11811181
/// Given two error types, merge them into the "union" of both error types
11821182
/// that is a supertype of both error types.

0 commit comments

Comments
 (0)