Skip to content

Check thrown error types of for applications/subscripts/property access #69198

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Oct 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions include/swift/AST/ASTScope.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#define SWIFT_AST_AST_SCOPE_H

#include "swift/AST/ASTNode.h"
#include "swift/AST/CatchNode.h"
#include "swift/AST/NameLookup.h"
#include "swift/AST/SimpleRequest.h"
#include "swift/Basic/Compiler.h"
Expand Down Expand Up @@ -85,6 +86,7 @@ class SILGenFunction;

namespace ast_scope {
class ASTScopeImpl;
class BraceStmtScope;
class GenericTypeOrExtensionScope;
class IterableTypeScope;
class TypeAliasScope;
Expand Down Expand Up @@ -211,6 +213,7 @@ class ASTScopeImpl : public ASTAllocated<ASTScopeImpl> {
#pragma mark common queries
public:
virtual NullablePtr<AbstractClosureExpr> getClosureIfClosureScope() const;
virtual NullablePtr<const BraceStmtScope> getAsBraceStmtScope() const;
virtual ASTContext &getASTContext() const;
virtual NullablePtr<Decl> getDeclIfAny() const { return nullptr; };
virtual NullablePtr<Stmt> getStmtIfAny() const { return nullptr; };
Expand Down Expand Up @@ -287,10 +290,18 @@ class ASTScopeImpl : public ASTAllocated<ASTScopeImpl> {
SourceFile *sourceFile, SourceLoc loc,
llvm::function_ref<bool(ASTScope::PotentialMacro)> consume);

static CatchNode lookupCatchNode(ModuleDecl *module, SourceLoc loc);

/// Scopes that cannot bind variables may set this to true to create more
/// compact scope tree in the debug info.
virtual bool ignoreInDebugInfo() const { return false; }

/// If this scope node represents a potential catch node, return body the
/// AST node describing the catch (a function, closure, or do...catch) and
/// the node of it's "body", i.e., the brace statement from which errors
/// thrown will be caught by that node.
virtual std::pair<CatchNode, const BraceStmtScope *> getCatchNodeBody() const;

#pragma mark - - lookup- starting point
private:
static const ASTScopeImpl *findStartingScopeForLookup(SourceFile *,
Expand Down Expand Up @@ -824,6 +835,8 @@ class FunctionBodyScope : public ASTScopeImpl {
Decl *getDecl() const { return decl; }
bool ignoreInDebugInfo() const override { return true; }

std::pair<CatchNode, const BraceStmtScope *> getCatchNodeBody() const override;

protected:
bool lookupLocalsOrMembers(DeclConsumer) const override;

Expand Down Expand Up @@ -1069,6 +1082,8 @@ class ClosureParametersScope final : public ASTScopeImpl {
NullablePtr<AbstractClosureExpr> getClosureIfClosureScope() const override {
return closureExpr;
}
std::pair<CatchNode, const BraceStmtScope *> getCatchNodeBody() const override;

NullablePtr<Expr> getExprIfAny() const override { return closureExpr; }
Expr *getExpr() const { return closureExpr; }
bool ignoreInDebugInfo() const override { return true; }
Expand Down Expand Up @@ -1440,6 +1455,8 @@ class DoCatchStmtScope final : public AbstractStmtScope {
void expandAScopeThatDoesNotCreateANewInsertionPoint(ScopeCreator &);

public:
std::pair<CatchNode, const BraceStmtScope *> getCatchNodeBody() const override;

std::string getClassName() const override;
Stmt *getStmt() const override { return stmt; }
};
Expand Down Expand Up @@ -1648,6 +1665,8 @@ class BraceStmtScope final : public AbstractStmtScope {
NullablePtr<AbstractClosureExpr> parentClosureIfAny() const; // public??
Stmt *getStmt() const override { return stmt; }

NullablePtr<const BraceStmtScope> getAsBraceStmtScope() const override;

protected:
bool lookupLocalsOrMembers(DeclConsumer) const override;
};
Expand Down
43 changes: 43 additions & 0 deletions include/swift/AST/CatchNode.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
//===--- CatchNode.h - An AST node that catches errors -----------*- C++-*-===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//

#ifndef SWIFT_AST_CATCHNODE_H
#define SWIFT_AST_CATCHNODE_H

#include "llvm/ADT/Optional.h"
#include "llvm/ADT/PointerUnion.h"
#include "swift/AST/Decl.h"
#include "swift/AST/Expr.h"
#include "swift/AST/Stmt.h"

namespace swift {

/// An AST node that represents a point where a thrown error can be caught and
/// or rethrown, which includes functions do...catch statements.
class CatchNode: public llvm::PointerUnion<
AbstractFunctionDecl *, AbstractClosureExpr *, DoCatchStmt *
> {
public:
using PointerUnion::PointerUnion;

/// Determine the thrown error type within the region of this catch node
/// where it will catch (and possibly rethrow) errors. All of the errors
/// thrown from within that region will be converted to this error type.
///
/// Returns the thrown error type for a throwing context, or \c llvm::None
/// if this is a non-throwing context.
llvm::Optional<Type> getThrownErrorTypeInContext(ASTContext &ctx) const;
};

} // end namespace swift

#endif // SWIFT_AST_CATCHNODE_H
20 changes: 20 additions & 0 deletions include/swift/AST/NameLookup.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#define SWIFT_AST_NAME_LOOKUP_H

#include "swift/AST/ASTVisitor.h"
#include "swift/AST/CatchNode.h"
#include "swift/AST/GenericSignature.h"
#include "swift/AST/Identifier.h"
#include "swift/AST/Module.h"
Expand Down Expand Up @@ -833,6 +834,25 @@ class ASTScope : public ASTAllocated<ASTScope> {
SourceFile *sourceFile, SourceLoc loc,
llvm::function_ref<bool(PotentialMacro macro)> consume);

/// Look up the scope tree for the nearest point at which an error thrown from
/// this location can be caught or rethrown.
///
/// For example, given this code:
///
/// \code
/// func f() throws {
/// do {
/// try g() // A
/// } catch {
/// throw ErrorWrapper(error) // B
/// }
/// }
/// \endcode
///
/// At the point marked A, the catch node is the enclosing do...catch
/// statement. At the point marked B, the catch node is the function itself.
static CatchNode lookupCatchNode(ModuleDecl *module, SourceLoc loc);

SWIFT_DEBUG_DUMP;
void print(llvm::raw_ostream &) const;
void dumpOneScopeMapLocation(std::pair<unsigned, unsigned>);
Expand Down
59 changes: 59 additions & 0 deletions lib/AST/ASTScope.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ void ASTScope::lookupEnclosingMacroScope(
return ASTScopeImpl::lookupEnclosingMacroScope(sourceFile, loc, body);
}

CatchNode ASTScope::lookupCatchNode(ModuleDecl *module, SourceLoc loc) {
return ASTScopeImpl::lookupCatchNode(module, loc);
}

#if SWIFT_COMPILER_IS_MSVC
#pragma warning(push)
#pragma warning(disable : 4996)
Expand Down Expand Up @@ -97,10 +101,65 @@ NullablePtr<AbstractClosureExpr> BraceStmtScope::parentClosureIfAny() const {
return !getParent() ? nullptr : getParent().get()->getClosureIfClosureScope();
}

NullablePtr<const BraceStmtScope> BraceStmtScope::getAsBraceStmtScope() const {
return this;
}

NullablePtr<AbstractClosureExpr> ASTScopeImpl::getClosureIfClosureScope() const {
return nullptr;
}

NullablePtr<const BraceStmtScope> ASTScopeImpl::getAsBraceStmtScope() const {
return nullptr;
}

std::pair<CatchNode, const BraceStmtScope *>
ASTScopeImpl::getCatchNodeBody() const {
return { nullptr, nullptr };
}

std::pair<CatchNode, const BraceStmtScope *>
ClosureParametersScope::getCatchNodeBody() const {
const BraceStmtScope *body = nullptr;
const auto &children = getChildren();
if (!children.empty()) {
body = children[0]->getAsBraceStmtScope().getPtrOrNull();
assert(body && "Not a brace statement?");
}
return { const_cast<AbstractClosureExpr *>(closureExpr), body };
}

std::pair<CatchNode, const BraceStmtScope *>
FunctionBodyScope::getCatchNodeBody() const {
const BraceStmtScope *body = nullptr;
const auto &children = getChildren();
if (!children.empty()) {
body = children[0]->getAsBraceStmtScope().getPtrOrNull();
assert(body && "Not a brace statement?");
}
return { const_cast<AbstractFunctionDecl *>(decl), body };
}

/// Determine whether this is an empty brace statement, which doesn't have a
/// node associated with it.
static bool isEmptyBraceStmt(Stmt *stmt) {
if (auto braceStmt = dyn_cast_or_null<BraceStmt>(stmt))
return braceStmt->empty();

return false;
}

std::pair<CatchNode, const BraceStmtScope *>
DoCatchStmtScope::getCatchNodeBody() const {
const BraceStmtScope *body = nullptr;
const auto &children = getChildren();
if (!children.empty() && !isEmptyBraceStmt(stmt->getBody())) {
body = children[0]->getAsBraceStmtScope().getPtrOrNull();
assert(body && "Not a brace statement?");
}
return { const_cast<DoCatchStmt *>(stmt), body };
}

SourceManager &ASTScopeImpl::getSourceManager() const {
return getASTContext().SourceMgr;
}
Expand Down
27 changes: 27 additions & 0 deletions lib/AST/ASTScopeLookup.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -712,3 +712,30 @@ void ASTScopeImpl::lookupEnclosingMacroScope(

} while ((scope = scope->getParent().getPtrOrNull()));
}

CatchNode ASTScopeImpl::lookupCatchNode(ModuleDecl *module, SourceLoc loc) {
auto sourceFile = module->getSourceFileContainingLocation(loc);
if (!sourceFile)
return nullptr;

auto *fileScope = sourceFile->getScope().impl;
const auto *innermost = fileScope->findInnermostEnclosingScope(
module, loc, nullptr);
ASTScopeAssert(innermost->getWasExpanded(),
"If looking in a scope, it must have been expanded.");

// Look for a body scope that's the
const BraceStmtScope *innerBodyScope = nullptr;
for (auto scope = innermost; scope; scope = scope->getParent().getPtrOrNull()) {
// If we are at a catch node and in the body of the region from which that
// node catches thrown errors, we have our result.
auto caught = scope->getCatchNodeBody();
if (caught.first && caught.second == innerBodyScope) {
return caught.first;
}

innerBodyScope = scope->getAsBraceStmtScope().getPtrOrNull();
}

return nullptr;
}
43 changes: 41 additions & 2 deletions lib/AST/Decl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -957,14 +957,24 @@ Type AbstractFunctionDecl::getThrownInterfaceType() const {

llvm::Optional<Type>
AbstractFunctionDecl::getEffectiveThrownErrorType() const {
// FIXME: Only getters can have thrown error types right now, and DidSet
// has a cyclic reference if we try to get its interface type here. Find a
// better way to express this.
if (auto accessor = dyn_cast<AccessorDecl>(this)) {
if (accessor->getAccessorKind() != AccessorKind::Get)
return llvm::None;
}

Type interfaceType = getInterfaceType();
if (hasImplicitSelfDecl()) {
if (auto fnType = interfaceType->getAs<AnyFunctionType>())
interfaceType = fnType->getResult();
}

return interfaceType->castTo<AnyFunctionType>()
->getEffectiveThrownErrorType();
if (auto fnType = interfaceType->getAs<AnyFunctionType>())
return fnType->getEffectiveThrownErrorType();

return llvm::None;
}

Expr *AbstractFunctionDecl::getSingleExpressionBody() const {
Expand Down Expand Up @@ -11398,3 +11408,32 @@ MacroDiscriminatorContext::getParentOf(FreestandingMacroExpansion *expansion) {
return getParentOf(
expansion->getPoundLoc(), expansion->getDeclContext());
}

llvm::Optional<Type>
CatchNode::getThrownErrorTypeInContext(ASTContext &ctx) const {
if (auto func = dyn_cast<AbstractFunctionDecl *>()) {
if (auto thrownError = func->getEffectiveThrownErrorType())
return func->mapTypeIntoContext(*thrownError);

return llvm::None;
}

if (auto closure = dyn_cast<AbstractClosureExpr *>()) {
if (closure->getType())
return closure->getEffectiveThrownType();

// FIXME: Should we lazily compute this?
return llvm::None;
}

auto doCatch = get<DoCatchStmt *>();
if (auto thrownError = doCatch->getCaughtErrorType()) {
if (thrownError->isNever())
return llvm::None;

return thrownError;
}

// If we haven't computed the error type yet, do so now.
return ctx.getErrorExistentialType();
}
2 changes: 1 addition & 1 deletion lib/AST/RequirementMachine/RequirementLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -599,7 +599,7 @@ struct InferRequirementsWalker : public TypeWalker {
DifferentiabilityKind::Linear);
}

// Infer that the thrown error type conforms to Error.
// Infer that the thrown error type of a function type conforms to Error.
if (auto thrownError = fnTy->getThrownError()) {
if (auto errorProtocol = ctx.getErrorDecl()) {
addConformanceConstraint(thrownError, errorProtocol);
Expand Down
9 changes: 6 additions & 3 deletions lib/AST/Stmt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -473,12 +473,15 @@ bool DoCatchStmt::isSyntacticallyExhaustive() const {
}

Type DoCatchStmt::getCaughtErrorType() const {
return getCatches()
auto firstPattern = getCatches()
.front()
->getCaseLabelItems()
.front()
.getPattern()
->getType();
.getPattern();
if (firstPattern->hasType())
return firstPattern->getType();

return Type();
}

void LabeledConditionalStmt::setCond(StmtCondition e) {
Expand Down
Loading