Skip to content

[Typed throws] Cleanups for the caught error type computation #70397

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 0 additions & 18 deletions include/swift/AST/AnyFunctionRef.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,24 +124,6 @@ class AnyFunctionRef {
return TheFunction.get<AbstractClosureExpr *>()->getType();
}

Type getThrownErrorType() const {
if (auto *AFD = TheFunction.dyn_cast<AbstractFunctionDecl *>()) {
if (Type thrownError = AFD->getThrownInterfaceType())
return AFD->mapTypeIntoContext(thrownError);

return Type();
}

Type closureType = TheFunction.get<AbstractClosureExpr *>()->getType();
if (!closureType)
return Type();

if (auto closureFnType = closureType->getAs<AnyFunctionType>())
return closureFnType->getThrownError();

return Type();
}

Type getBodyResultType() const {
if (auto *AFD = TheFunction.dyn_cast<AbstractFunctionDecl *>()) {
if (auto *FD = dyn_cast<FuncDecl>(AFD))
Expand Down
8 changes: 8 additions & 0 deletions include/swift/AST/CatchNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#ifndef SWIFT_AST_CATCHNODE_H
#define SWIFT_AST_CATCHNODE_H

#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/PointerUnion.h"
#include "swift/AST/Decl.h"
Expand All @@ -36,8 +37,15 @@ class CatchNode: public llvm::PointerUnion<
/// Returns the thrown error type for a throwing context, or \c llvm::None
/// if this is a non-throwing context.
llvm::Optional<Type> getThrownErrorTypeInContext(DeclContext *dc) const;

friend llvm::hash_code hash_value(CatchNode catchNode) {
using llvm::hash_value;
return hash_value(catchNode.getOpaqueValue());
}
};

void simple_display(llvm::raw_ostream &out, CatchNode catchNode);

} // end namespace swift

#endif // SWIFT_AST_CATCHNODE_H
2 changes: 1 addition & 1 deletion include/swift/AST/Decl.h
Original file line number Diff line number Diff line change
Expand Up @@ -6989,7 +6989,7 @@ void simple_display(llvm::raw_ostream &out, BodyAndFingerprint value);
/// Base class for function-like declarations.
class AbstractFunctionDecl : public GenericContext, public ValueDecl {
friend class NeedsNewVTableEntryRequest;
friend class ThrownTypeRequest;
friend class ExplicitCaughtTypeRequest;

public:
/// records the kind of SILGen-synthesized body this decl represents
Expand Down
13 changes: 4 additions & 9 deletions include/swift/AST/Expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -3987,6 +3987,8 @@ class SerializedAbstractClosureExpr : public SerializedLocalDeclContext {
/// { [weak c] (a : Int) -> Int in a + c!.getFoo() }
/// \endcode
class ClosureExpr : public AbstractClosureExpr {
friend class ExplicitCaughtTypeRequest;

public:
enum class BodyState {
/// The body was parsed, but not ready for type checking because
Expand Down Expand Up @@ -4034,7 +4036,7 @@ class ClosureExpr : public AbstractClosureExpr {
/// The location of the "in", if present.
SourceLoc InLoc;

/// The explcitly-specified thrown type.
/// The explicitly-specified thrown type.
TypeExpr *ThrownType;

/// The explicitly-specified result type.
Expand Down Expand Up @@ -4149,14 +4151,7 @@ class ClosureExpr : public AbstractClosureExpr {
}

/// Retrieve the explicitly-thrown type.
Type getExplicitThrownType() const {
if (ThrownType)
return ThrownType->getInstanceType();

return nullptr;
}

void setExplicitThrownType(Type thrownType);
Type getExplicitThrownType() const;

/// Retrieve the explicitly-thrown type representation.
TypeRepr *getExplicitThrownTypeRepr() const {
Expand Down
10 changes: 5 additions & 5 deletions include/swift/AST/Stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -1387,7 +1387,7 @@ class DoCatchStmt final
: public LabeledStmt,
private llvm::TrailingObjects<DoCatchStmt, CaseStmt *> {
friend TrailingObjects;
friend class DoCatchExplicitThrownTypeRequest;
friend class ExplicitCaughtTypeRequest;

SourceLoc DoLoc;

Expand Down Expand Up @@ -1429,13 +1429,13 @@ class DoCatchStmt final
SourceLoc getStartLoc() const { return getLabelLocOrKeywordLoc(DoLoc); }
SourceLoc getEndLoc() const { return getCatches().back()->getEndLoc(); }

/// Retrieves the type representation for the thrown type.
TypeRepr *getThrownTypeRepr() const {
/// Retrieves the type representation for the caught type.
TypeRepr *getCaughtTypeRepr() const {
return ThrownType.getTypeRepr();
}

// Get the explicitly-specified thrown error type.
Type getExplicitlyThrownType(DeclContext *dc) const;
// Get the explicitly-specified caught error type.
Type getExplicitCaughtType(DeclContext *dc) const;

Stmt *getBody() const { return Body; }
void setBody(Stmt *s) { Body = s; }
Expand Down
38 changes: 12 additions & 26 deletions include/swift/AST/TypeCheckRequests.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "swift/AST/AnyFunctionRef.h"
#include "swift/AST/ASTNode.h"
#include "swift/AST/ASTTypeIDs.h"
#include "swift/AST/CatchNode.h"
#include "swift/AST/Effects.h"
#include "swift/AST/GenericParamList.h"
#include "swift/AST/GenericSignature.h"
Expand Down Expand Up @@ -2283,31 +2284,16 @@ class ParamSpecifierRequest
void cacheResult(ParamSpecifier value) const;
};

/// Determines the explicitly-written thrown result type of a function.
class ThrownTypeRequest
: public SimpleRequest<ThrownTypeRequest,
Type(AbstractFunctionDecl *),
RequestFlags::SeparatelyCached> {
public:
using SimpleRequest::SimpleRequest;

private:
friend SimpleRequest;

// Evaluation.
Type evaluate(Evaluator &evaluator, AbstractFunctionDecl *func) const;

public:
// Separate caching.
bool isCached() const { return true; }
llvm::Optional<Type> getCachedResult() const;
void cacheResult(Type value) const;
};

/// Determines the explicitly-written thrown error type in a do..catch block.
class DoCatchExplicitThrownTypeRequest
: public SimpleRequest<DoCatchExplicitThrownTypeRequest,
Type(DeclContext *, DoCatchStmt *),
/// Determines the explicitly-written caught result type for any catch node,
/// including functions/closures and do..catch statements.
///
/// Returns the caught result type for the catch node, which will be
/// `Never` if no error can be thrown from within the context (e.g., a
/// non-throwing function). Returns a NULL type if the caught error type
/// requires type inference.
class ExplicitCaughtTypeRequest
: public SimpleRequest<ExplicitCaughtTypeRequest,
Type(DeclContext *, CatchNode),
RequestFlags::SeparatelyCached> {
public:
using SimpleRequest::SimpleRequest;
Expand All @@ -2316,7 +2302,7 @@ class DoCatchExplicitThrownTypeRequest
friend SimpleRequest;

// Evaluation.
Type evaluate(Evaluator &evaluator, DeclContext *dc, DoCatchStmt *stmt) const;
Type evaluate(Evaluator &evaluator, DeclContext *dc, CatchNode catchNode) const;

public:
// Separate caching.
Expand Down
7 changes: 2 additions & 5 deletions include/swift/AST/TypeCheckerTypeIDZone.def
Original file line number Diff line number Diff line change
Expand Up @@ -358,11 +358,8 @@ SWIFT_REQUEST(TypeChecker, NeedsNewVTableEntryRequest,
bool(AbstractFunctionDecl *), SeparatelyCached, NoLocationInfo)
SWIFT_REQUEST(TypeChecker, ParamSpecifierRequest,
ParamDecl::Specifier(ParamDecl *), SeparatelyCached, NoLocationInfo)
SWIFT_REQUEST(TypeChecker, ThrownTypeRequest,
Type(AbstractFunctionDecl *), SeparatelyCached, NoLocationInfo)
SWIFT_REQUEST(TypeChecker, DoCatchExplicitThrownTypeRequest,
Type(DeclContext *, DoCatchStmt *), SeparatelyCached,
NoLocationInfo)
SWIFT_REQUEST(TypeChecker, ExplicitCaughtTypeRequest,
Type(DeclContext *, CatchNode), SeparatelyCached, NoLocationInfo)
SWIFT_REQUEST(TypeChecker, ResultTypeRequest,
Type(ValueDecl *), SeparatelyCached, NoLocationInfo)
SWIFT_REQUEST(TypeChecker, AreAllStoredPropertiesDefaultInitableRequest,
Expand Down
115 changes: 110 additions & 5 deletions lib/AST/Decl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -979,9 +979,13 @@ bool Decl::preconcurrency() const {
}

Type AbstractFunctionDecl::getThrownInterfaceType() const {
if (!getThrownTypeRepr())
return ThrownType.getType();

auto mutableThis = const_cast<AbstractFunctionDecl *>(this);
return evaluateOrDefault(
getASTContext().evaluator,
ThrownTypeRequest{const_cast<AbstractFunctionDecl *>(this)},
ExplicitCaughtTypeRequest{mutableThis, mutableThis},
Type());
}

Expand Down Expand Up @@ -11698,11 +11702,19 @@ CatchNode::getThrownErrorTypeInContext(DeclContext *dc) const {
return llvm::None;
}

if (auto closure = dyn_cast<AbstractClosureExpr *>()) {
if (closure->getType())
return closure->getEffectiveThrownType();
if (auto abstractClosure = dyn_cast<AbstractClosureExpr *>()) {
if (abstractClosure->getType())
return abstractClosure->getEffectiveThrownType();

if (auto closure = llvm::dyn_cast<ClosureExpr>(abstractClosure)) {
if (Type thrownType = closure->getExplicitThrownType()) {
if (thrownType->isNever())
return llvm::None;

return thrownType;
}
}

// FIXME: Should we lazily compute this?
return llvm::None;
}

Expand Down Expand Up @@ -11737,3 +11749,96 @@ CatchNode::getThrownErrorTypeInContext(DeclContext *dc) const {

llvm_unreachable("Unhandled catch node kind");
}

void swift::simple_display(llvm::raw_ostream &out, CatchNode catchNode) {
out << "catch node";
}

//----------------------------------------------------------------------------//
// ExplicitCaughtTypeRequest computation.
//----------------------------------------------------------------------------//
bool ExplicitCaughtTypeRequest::isCached() const {
auto catchNode = std::get<1>(getStorage());

// try? and try! never need to be cached.
if (catchNode.is<AnyTryExpr *>())
return false;

// Functions with explicitly-written thrown types need the result cached.
if (auto func = catchNode.dyn_cast<AbstractFunctionDecl *>()) {
return func->ThrownType.getTypeRepr() != nullptr;
}

// Closures with explicitly-written thrown types need the result cached.
if (auto abstractClosure = catchNode.dyn_cast<AbstractClosureExpr *>()) {
if (auto closure = dyn_cast<ClosureExpr>(abstractClosure)) {
return closure->ThrownType != nullptr;
}

return false;
}

// Do..catch with explicitly-written thrown types need the result cached.
if (auto doCatch = catchNode.dyn_cast<DoCatchStmt *>()) {
return doCatch->getThrowsLoc().isValid();
}

llvm_unreachable("Unhandled catch node");
}

llvm::Optional<Type> ExplicitCaughtTypeRequest::getCachedResult() const {
// Map a possibly-null Type to llvm::Optional<Type>.
auto nonnullTypeOrNone = [](Type type) -> llvm::Optional<Type> {
if (type.isNull())
return llvm::None;

return type;
};

auto catchNode = std::get<1>(getStorage());

if (auto func = catchNode.dyn_cast<AbstractFunctionDecl *>()) {
return nonnullTypeOrNone(func->ThrownType.getType());
}

if (auto abstractClosure = catchNode.dyn_cast<AbstractClosureExpr *>()) {
auto closure = cast<ClosureExpr>(abstractClosure);
if (closure->ThrownType) {
return nonnullTypeOrNone(closure->ThrownType->getInstanceType());
}

return llvm::None;
}

if (auto doCatch = catchNode.dyn_cast<DoCatchStmt *>()) {
return nonnullTypeOrNone(doCatch->ThrownType.getType());
}

llvm_unreachable("Unhandled catch node");
}

void ExplicitCaughtTypeRequest::cacheResult(Type type) const {
auto catchNode = std::get<1>(getStorage());

if (auto func = catchNode.dyn_cast<AbstractFunctionDecl *>()) {
func->ThrownType.setType(type);
return;
}

if (auto abstractClosure = catchNode.dyn_cast<AbstractClosureExpr *>()) {
auto closure = cast<ClosureExpr>(abstractClosure);
if (closure->ThrownType)
closure->ThrownType->setType(MetatypeType::get(type));
else
closure->ThrownType =
TypeExpr::createImplicit(type, type->getASTContext());
return;
}

if (auto doCatch = catchNode.dyn_cast<DoCatchStmt *>()) {
doCatch->ThrownType.setType(type);
return;
}

llvm_unreachable("Unhandled catch node");
}
19 changes: 12 additions & 7 deletions lib/AST/Expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1937,8 +1937,10 @@ Type AbstractClosureExpr::getResultType(
}

llvm::Optional<Type> AbstractClosureExpr::getEffectiveThrownType() const {
return getType()->castTo<AnyFunctionType>()
->getEffectiveThrownErrorType();
if (auto fnType = getType()->getAs<AnyFunctionType>())
return fnType->getEffectiveThrownErrorType();

return llvm::None;
}

bool AbstractClosureExpr::isBodyThrowing() const {
Expand Down Expand Up @@ -2046,11 +2048,14 @@ bool ClosureExpr::hasEmptyBody() const {
return getBody()->empty();
}

void ClosureExpr::setExplicitThrownType(Type thrownType) {
assert(thrownType && !thrownType->hasTypeVariable() &&
!thrownType->hasPlaceholder());
assert(ThrownType);
ThrownType->setType(MetatypeType::get(thrownType));
Type ClosureExpr::getExplicitThrownType() const {
if (getThrowsLoc().isInvalid())
return Type();

ASTContext &ctx = getASTContext();
auto mutableThis = const_cast<ClosureExpr *>(this);
ExplicitCaughtTypeRequest request{mutableThis, mutableThis};
return evaluateOrDefault(ctx.evaluator, request, Type());
}

void ClosureExpr::setExplicitResultType(Type ty) {
Expand Down
6 changes: 3 additions & 3 deletions lib/AST/Stmt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -478,15 +478,15 @@ bool DoCatchStmt::isSyntacticallyExhaustive() const {
return false;
}

Type DoCatchStmt::getExplicitlyThrownType(DeclContext *dc) const {
Type DoCatchStmt::getExplicitCaughtType(DeclContext *dc) const {
ASTContext &ctx = dc->getASTContext();
DoCatchExplicitThrownTypeRequest request{dc, const_cast<DoCatchStmt *>(this)};
ExplicitCaughtTypeRequest request{dc, const_cast<DoCatchStmt *>(this)};
return evaluateOrDefault(ctx.evaluator, request, Type());
}

Type DoCatchStmt::getCaughtErrorType(DeclContext *dc) const {
// Check for an explicitly-specified error type.
if (Type explicitError = getExplicitlyThrownType(dc))
if (Type explicitError = getExplicitCaughtType(dc))
return explicitError;

auto firstPattern = getCatches()
Expand Down
Loading