Skip to content

Commit a0d1c35

Browse files
authored
Merge pull request #69198 from DougGregor/typed-throws-rethrow-checking
Check thrown error types of for applications/subscripts/property access
2 parents 25061fb + 9ffed75 commit a0d1c35

File tree

13 files changed

+327
-54
lines changed

13 files changed

+327
-54
lines changed

include/swift/AST/ASTScope.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#define SWIFT_AST_AST_SCOPE_H
3030

3131
#include "swift/AST/ASTNode.h"
32+
#include "swift/AST/CatchNode.h"
3233
#include "swift/AST/NameLookup.h"
3334
#include "swift/AST/SimpleRequest.h"
3435
#include "swift/Basic/Compiler.h"
@@ -85,6 +86,7 @@ class SILGenFunction;
8586

8687
namespace ast_scope {
8788
class ASTScopeImpl;
89+
class BraceStmtScope;
8890
class GenericTypeOrExtensionScope;
8991
class IterableTypeScope;
9092
class TypeAliasScope;
@@ -211,6 +213,7 @@ class ASTScopeImpl : public ASTAllocated<ASTScopeImpl> {
211213
#pragma mark common queries
212214
public:
213215
virtual NullablePtr<AbstractClosureExpr> getClosureIfClosureScope() const;
216+
virtual NullablePtr<const BraceStmtScope> getAsBraceStmtScope() const;
214217
virtual ASTContext &getASTContext() const;
215218
virtual NullablePtr<Decl> getDeclIfAny() const { return nullptr; };
216219
virtual NullablePtr<Stmt> getStmtIfAny() const { return nullptr; };
@@ -287,10 +290,18 @@ class ASTScopeImpl : public ASTAllocated<ASTScopeImpl> {
287290
SourceFile *sourceFile, SourceLoc loc,
288291
llvm::function_ref<bool(ASTScope::PotentialMacro)> consume);
289292

293+
static CatchNode lookupCatchNode(ModuleDecl *module, SourceLoc loc);
294+
290295
/// Scopes that cannot bind variables may set this to true to create more
291296
/// compact scope tree in the debug info.
292297
virtual bool ignoreInDebugInfo() const { return false; }
293298

299+
/// If this scope node represents a potential catch node, return body the
300+
/// AST node describing the catch (a function, closure, or do...catch) and
301+
/// the node of it's "body", i.e., the brace statement from which errors
302+
/// thrown will be caught by that node.
303+
virtual std::pair<CatchNode, const BraceStmtScope *> getCatchNodeBody() const;
304+
294305
#pragma mark - - lookup- starting point
295306
private:
296307
static const ASTScopeImpl *findStartingScopeForLookup(SourceFile *,
@@ -824,6 +835,8 @@ class FunctionBodyScope : public ASTScopeImpl {
824835
Decl *getDecl() const { return decl; }
825836
bool ignoreInDebugInfo() const override { return true; }
826837

838+
std::pair<CatchNode, const BraceStmtScope *> getCatchNodeBody() const override;
839+
827840
protected:
828841
bool lookupLocalsOrMembers(DeclConsumer) const override;
829842

@@ -1069,6 +1082,8 @@ class ClosureParametersScope final : public ASTScopeImpl {
10691082
NullablePtr<AbstractClosureExpr> getClosureIfClosureScope() const override {
10701083
return closureExpr;
10711084
}
1085+
std::pair<CatchNode, const BraceStmtScope *> getCatchNodeBody() const override;
1086+
10721087
NullablePtr<Expr> getExprIfAny() const override { return closureExpr; }
10731088
Expr *getExpr() const { return closureExpr; }
10741089
bool ignoreInDebugInfo() const override { return true; }
@@ -1440,6 +1455,8 @@ class DoCatchStmtScope final : public AbstractStmtScope {
14401455
void expandAScopeThatDoesNotCreateANewInsertionPoint(ScopeCreator &);
14411456

14421457
public:
1458+
std::pair<CatchNode, const BraceStmtScope *> getCatchNodeBody() const override;
1459+
14431460
std::string getClassName() const override;
14441461
Stmt *getStmt() const override { return stmt; }
14451462
};
@@ -1648,6 +1665,8 @@ class BraceStmtScope final : public AbstractStmtScope {
16481665
NullablePtr<AbstractClosureExpr> parentClosureIfAny() const; // public??
16491666
Stmt *getStmt() const override { return stmt; }
16501667

1668+
NullablePtr<const BraceStmtScope> getAsBraceStmtScope() const override;
1669+
16511670
protected:
16521671
bool lookupLocalsOrMembers(DeclConsumer) const override;
16531672
};

include/swift/AST/CatchNode.h

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
//===--- CatchNode.h - An AST node that catches errors -----------*- C++-*-===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef SWIFT_AST_CATCHNODE_H
14+
#define SWIFT_AST_CATCHNODE_H
15+
16+
#include "llvm/ADT/Optional.h"
17+
#include "llvm/ADT/PointerUnion.h"
18+
#include "swift/AST/Decl.h"
19+
#include "swift/AST/Expr.h"
20+
#include "swift/AST/Stmt.h"
21+
22+
namespace swift {
23+
24+
/// An AST node that represents a point where a thrown error can be caught and
25+
/// or rethrown, which includes functions do...catch statements.
26+
class CatchNode: public llvm::PointerUnion<
27+
AbstractFunctionDecl *, AbstractClosureExpr *, DoCatchStmt *
28+
> {
29+
public:
30+
using PointerUnion::PointerUnion;
31+
32+
/// Determine the thrown error type within the region of this catch node
33+
/// where it will catch (and possibly rethrow) errors. All of the errors
34+
/// thrown from within that region will be converted to this error type.
35+
///
36+
/// Returns the thrown error type for a throwing context, or \c llvm::None
37+
/// if this is a non-throwing context.
38+
llvm::Optional<Type> getThrownErrorTypeInContext(ASTContext &ctx) const;
39+
};
40+
41+
} // end namespace swift
42+
43+
#endif // SWIFT_AST_CATCHNODE_H

include/swift/AST/NameLookup.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#define SWIFT_AST_NAME_LOOKUP_H
1919

2020
#include "swift/AST/ASTVisitor.h"
21+
#include "swift/AST/CatchNode.h"
2122
#include "swift/AST/GenericSignature.h"
2223
#include "swift/AST/Identifier.h"
2324
#include "swift/AST/Module.h"
@@ -833,6 +834,25 @@ class ASTScope : public ASTAllocated<ASTScope> {
833834
SourceFile *sourceFile, SourceLoc loc,
834835
llvm::function_ref<bool(PotentialMacro macro)> consume);
835836

837+
/// Look up the scope tree for the nearest point at which an error thrown from
838+
/// this location can be caught or rethrown.
839+
///
840+
/// For example, given this code:
841+
///
842+
/// \code
843+
/// func f() throws {
844+
/// do {
845+
/// try g() // A
846+
/// } catch {
847+
/// throw ErrorWrapper(error) // B
848+
/// }
849+
/// }
850+
/// \endcode
851+
///
852+
/// At the point marked A, the catch node is the enclosing do...catch
853+
/// statement. At the point marked B, the catch node is the function itself.
854+
static CatchNode lookupCatchNode(ModuleDecl *module, SourceLoc loc);
855+
836856
SWIFT_DEBUG_DUMP;
837857
void print(llvm::raw_ostream &) const;
838858
void dumpOneScopeMapLocation(std::pair<unsigned, unsigned>);

lib/AST/ASTScope.cpp

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,10 @@ void ASTScope::lookupEnclosingMacroScope(
6666
return ASTScopeImpl::lookupEnclosingMacroScope(sourceFile, loc, body);
6767
}
6868

69+
CatchNode ASTScope::lookupCatchNode(ModuleDecl *module, SourceLoc loc) {
70+
return ASTScopeImpl::lookupCatchNode(module, loc);
71+
}
72+
6973
#if SWIFT_COMPILER_IS_MSVC
7074
#pragma warning(push)
7175
#pragma warning(disable : 4996)
@@ -97,10 +101,65 @@ NullablePtr<AbstractClosureExpr> BraceStmtScope::parentClosureIfAny() const {
97101
return !getParent() ? nullptr : getParent().get()->getClosureIfClosureScope();
98102
}
99103

104+
NullablePtr<const BraceStmtScope> BraceStmtScope::getAsBraceStmtScope() const {
105+
return this;
106+
}
107+
100108
NullablePtr<AbstractClosureExpr> ASTScopeImpl::getClosureIfClosureScope() const {
101109
return nullptr;
102110
}
103111

112+
NullablePtr<const BraceStmtScope> ASTScopeImpl::getAsBraceStmtScope() const {
113+
return nullptr;
114+
}
115+
116+
std::pair<CatchNode, const BraceStmtScope *>
117+
ASTScopeImpl::getCatchNodeBody() const {
118+
return { nullptr, nullptr };
119+
}
120+
121+
std::pair<CatchNode, const BraceStmtScope *>
122+
ClosureParametersScope::getCatchNodeBody() const {
123+
const BraceStmtScope *body = nullptr;
124+
const auto &children = getChildren();
125+
if (!children.empty()) {
126+
body = children[0]->getAsBraceStmtScope().getPtrOrNull();
127+
assert(body && "Not a brace statement?");
128+
}
129+
return { const_cast<AbstractClosureExpr *>(closureExpr), body };
130+
}
131+
132+
std::pair<CatchNode, const BraceStmtScope *>
133+
FunctionBodyScope::getCatchNodeBody() const {
134+
const BraceStmtScope *body = nullptr;
135+
const auto &children = getChildren();
136+
if (!children.empty()) {
137+
body = children[0]->getAsBraceStmtScope().getPtrOrNull();
138+
assert(body && "Not a brace statement?");
139+
}
140+
return { const_cast<AbstractFunctionDecl *>(decl), body };
141+
}
142+
143+
/// Determine whether this is an empty brace statement, which doesn't have a
144+
/// node associated with it.
145+
static bool isEmptyBraceStmt(Stmt *stmt) {
146+
if (auto braceStmt = dyn_cast_or_null<BraceStmt>(stmt))
147+
return braceStmt->empty();
148+
149+
return false;
150+
}
151+
152+
std::pair<CatchNode, const BraceStmtScope *>
153+
DoCatchStmtScope::getCatchNodeBody() const {
154+
const BraceStmtScope *body = nullptr;
155+
const auto &children = getChildren();
156+
if (!children.empty() && !isEmptyBraceStmt(stmt->getBody())) {
157+
body = children[0]->getAsBraceStmtScope().getPtrOrNull();
158+
assert(body && "Not a brace statement?");
159+
}
160+
return { const_cast<DoCatchStmt *>(stmt), body };
161+
}
162+
104163
SourceManager &ASTScopeImpl::getSourceManager() const {
105164
return getASTContext().SourceMgr;
106165
}

lib/AST/ASTScopeLookup.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -712,3 +712,30 @@ void ASTScopeImpl::lookupEnclosingMacroScope(
712712

713713
} while ((scope = scope->getParent().getPtrOrNull()));
714714
}
715+
716+
CatchNode ASTScopeImpl::lookupCatchNode(ModuleDecl *module, SourceLoc loc) {
717+
auto sourceFile = module->getSourceFileContainingLocation(loc);
718+
if (!sourceFile)
719+
return nullptr;
720+
721+
auto *fileScope = sourceFile->getScope().impl;
722+
const auto *innermost = fileScope->findInnermostEnclosingScope(
723+
module, loc, nullptr);
724+
ASTScopeAssert(innermost->getWasExpanded(),
725+
"If looking in a scope, it must have been expanded.");
726+
727+
// Look for a body scope that's the
728+
const BraceStmtScope *innerBodyScope = nullptr;
729+
for (auto scope = innermost; scope; scope = scope->getParent().getPtrOrNull()) {
730+
// If we are at a catch node and in the body of the region from which that
731+
// node catches thrown errors, we have our result.
732+
auto caught = scope->getCatchNodeBody();
733+
if (caught.first && caught.second == innerBodyScope) {
734+
return caught.first;
735+
}
736+
737+
innerBodyScope = scope->getAsBraceStmtScope().getPtrOrNull();
738+
}
739+
740+
return nullptr;
741+
}

lib/AST/Decl.cpp

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -957,14 +957,24 @@ Type AbstractFunctionDecl::getThrownInterfaceType() const {
957957

958958
llvm::Optional<Type>
959959
AbstractFunctionDecl::getEffectiveThrownErrorType() const {
960+
// FIXME: Only getters can have thrown error types right now, and DidSet
961+
// has a cyclic reference if we try to get its interface type here. Find a
962+
// better way to express this.
963+
if (auto accessor = dyn_cast<AccessorDecl>(this)) {
964+
if (accessor->getAccessorKind() != AccessorKind::Get)
965+
return llvm::None;
966+
}
967+
960968
Type interfaceType = getInterfaceType();
961969
if (hasImplicitSelfDecl()) {
962970
if (auto fnType = interfaceType->getAs<AnyFunctionType>())
963971
interfaceType = fnType->getResult();
964972
}
965973

966-
return interfaceType->castTo<AnyFunctionType>()
967-
->getEffectiveThrownErrorType();
974+
if (auto fnType = interfaceType->getAs<AnyFunctionType>())
975+
return fnType->getEffectiveThrownErrorType();
976+
977+
return llvm::None;
968978
}
969979

970980
Expr *AbstractFunctionDecl::getSingleExpressionBody() const {
@@ -11398,3 +11408,32 @@ MacroDiscriminatorContext::getParentOf(FreestandingMacroExpansion *expansion) {
1139811408
return getParentOf(
1139911409
expansion->getPoundLoc(), expansion->getDeclContext());
1140011410
}
11411+
11412+
llvm::Optional<Type>
11413+
CatchNode::getThrownErrorTypeInContext(ASTContext &ctx) const {
11414+
if (auto func = dyn_cast<AbstractFunctionDecl *>()) {
11415+
if (auto thrownError = func->getEffectiveThrownErrorType())
11416+
return func->mapTypeIntoContext(*thrownError);
11417+
11418+
return llvm::None;
11419+
}
11420+
11421+
if (auto closure = dyn_cast<AbstractClosureExpr *>()) {
11422+
if (closure->getType())
11423+
return closure->getEffectiveThrownType();
11424+
11425+
// FIXME: Should we lazily compute this?
11426+
return llvm::None;
11427+
}
11428+
11429+
auto doCatch = get<DoCatchStmt *>();
11430+
if (auto thrownError = doCatch->getCaughtErrorType()) {
11431+
if (thrownError->isNever())
11432+
return llvm::None;
11433+
11434+
return thrownError;
11435+
}
11436+
11437+
// If we haven't computed the error type yet, do so now.
11438+
return ctx.getErrorExistentialType();
11439+
}

lib/AST/RequirementMachine/RequirementLowering.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -599,7 +599,7 @@ struct InferRequirementsWalker : public TypeWalker {
599599
DifferentiabilityKind::Linear);
600600
}
601601

602-
// Infer that the thrown error type conforms to Error.
602+
// Infer that the thrown error type of a function type conforms to Error.
603603
if (auto thrownError = fnTy->getThrownError()) {
604604
if (auto errorProtocol = ctx.getErrorDecl()) {
605605
addConformanceConstraint(thrownError, errorProtocol);

lib/AST/Stmt.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -473,12 +473,15 @@ bool DoCatchStmt::isSyntacticallyExhaustive() const {
473473
}
474474

475475
Type DoCatchStmt::getCaughtErrorType() const {
476-
return getCatches()
476+
auto firstPattern = getCatches()
477477
.front()
478478
->getCaseLabelItems()
479479
.front()
480-
.getPattern()
481-
->getType();
480+
.getPattern();
481+
if (firstPattern->hasType())
482+
return firstPattern->getType();
483+
484+
return Type();
482485
}
483486

484487
void LabeledConditionalStmt::setCond(StmtCondition e) {

0 commit comments

Comments
 (0)