Skip to content

Commit 119a719

Browse files
committed
[Typed throws] Teach associated type inference to infer from thrown errors
When comparing a requirement that uses typed throws and uses an associated type for the thrown error type against a potential witness, infer the associated type from the thrown error of the witness---whether explicitly specified, untyped throws (`any Error`), or non-throwing (`Never`).
1 parent 5380298 commit 119a719

15 files changed

+88
-32
lines changed

include/swift/AST/Decl.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7123,7 +7123,7 @@ class AbstractFunctionDecl : public GenericContext, public ValueDecl {
71237123
///
71247124
/// Functions with untyped throws will produce "any Error", functions that
71257125
/// cannot throw or are specified to throw "Never" will return llvm::None.
7126-
llvm::Optional<Type> getEffectiveThrownInterfaceType() const;
7126+
llvm::Optional<Type> getEffectiveThrownErrorType() const;
71277127

71287128
/// Returns if the function throws or is async.
71297129
bool hasEffect(EffectKind kind) const;

include/swift/AST/TypeMatcher.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,15 @@ class TypeMatcher {
397397
return false;
398398
}
399399

400+
// If requested, compare the thrown error types.
401+
Type thrownError1 = firstFunc->getEffectiveThrownErrorTypeOrNever();
402+
Type thrownError2 = secondFunc->getEffectiveThrownErrorTypeOrNever();
403+
if (Matcher.asDerived().considerThrownErrorTypes(thrownError1,
404+
thrownError2) &&
405+
!this->visit(thrownError1->getCanonicalType(),
406+
thrownError2, thrownError1))
407+
return false;
408+
400409
return this->visit(firstFunc.getResult(), secondFunc->getResult(),
401410
sugaredFirstFunc->getResult());
402411
}
@@ -558,6 +567,10 @@ class TypeMatcher {
558567
return MatchVisitor(*this).visit(first->getCanonicalType(), second,
559568
first);
560569
}
570+
571+
bool considerThrownErrorTypes(Type errorType1, Type errorType2) const {
572+
return false;
573+
}
561574
};
562575

563576
} // end namespace swift

include/swift/AST/Types.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3391,8 +3391,15 @@ class AnyFunctionType : public TypeBase {
33913391
///
33923392
/// Functions with untyped throws will produce "any Error", functions that
33933393
/// cannot throw or are specified to throw "Never" will return llvm::None.
3394-
llvm::Optional<Type> getEffectiveThrownInterfaceType() const;
3395-
3394+
llvm::Optional<Type> getEffectiveThrownErrorType() const;
3395+
3396+
/// Retrieve the "effective" thrown interface type, or `Never` if
3397+
/// this function cannot throw.
3398+
///
3399+
/// Functions with untyped throws will produce `any Error`, functions that
3400+
/// cannot throw or are specified to throw `Never` will return `Never`.
3401+
Type getEffectiveThrownErrorTypeOrNever() const;
3402+
33963403
/// Returns true if the function type stores a Clang type that cannot
33973404
/// be derived from its Swift type. Returns false otherwise, including if
33983405
/// the function type is not @convention(c) or @convention(block).

lib/AST/Decl.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -953,15 +953,15 @@ Type AbstractFunctionDecl::getThrownInterfaceType() const {
953953
}
954954

955955
llvm::Optional<Type>
956-
AbstractFunctionDecl::getEffectiveThrownInterfaceType() const {
956+
AbstractFunctionDecl::getEffectiveThrownErrorType() const {
957957
Type interfaceType = getInterfaceType();
958958
if (hasImplicitSelfDecl()) {
959959
if (auto fnType = interfaceType->getAs<AnyFunctionType>())
960960
interfaceType = fnType->getResult();
961961
}
962962

963963
return interfaceType->castTo<AnyFunctionType>()
964-
->getEffectiveThrownInterfaceType();
964+
->getEffectiveThrownErrorType();
965965
}
966966

967967
Expr *AbstractFunctionDecl::getSingleExpressionBody() const {

lib/AST/Expr.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1937,7 +1937,7 @@ Type AbstractClosureExpr::getResultType(
19371937

19381938
llvm::Optional<Type> AbstractClosureExpr::getEffectiveThrownType() const {
19391939
return getType()->castTo<AnyFunctionType>()
1940-
->getEffectiveThrownInterfaceType();
1940+
->getEffectiveThrownErrorType();
19411941
}
19421942

19431943
bool AbstractClosureExpr::isBodyThrowing() const {

lib/AST/Type.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5419,7 +5419,7 @@ AnyFunctionType *AnyFunctionType::getWithoutThrowing() const {
54195419
return withExtInfo(info);
54205420
}
54215421

5422-
llvm::Optional<Type> AnyFunctionType::getEffectiveThrownInterfaceType() const {
5422+
llvm::Optional<Type> AnyFunctionType::getEffectiveThrownErrorType() const {
54235423
// A non-throwing function... has no thrown interface type.
54245424
if (!isThrowing())
54255425
return llvm::None;
@@ -5437,6 +5437,13 @@ llvm::Optional<Type> AnyFunctionType::getEffectiveThrownInterfaceType() const {
54375437
return thrownError;
54385438
}
54395439

5440+
Type AnyFunctionType::getEffectiveThrownErrorTypeOrNever() const {
5441+
if (auto thrown = getEffectiveThrownErrorType())
5442+
return *thrown;
5443+
5444+
return getASTContext().getNeverType();
5445+
}
5446+
54405447
llvm::Optional<TangentSpace>
54415448
TypeBase::getAutoDiffTangentSpace(LookupConformanceFn lookupConformance) {
54425449
assert(lookupConformance);

lib/SILGen/SILGenBackDeploy.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ void SILGenFunction::emitBackDeploymentThunk(SILDeclRef thunk) {
238238
}
239239

240240
prepareEpilog(getResultInterfaceType(AFD),
241-
AFD->getEffectiveThrownInterfaceType(),
241+
AFD->getEffectiveThrownErrorType(),
242242
CleanupLocation(AFD));
243243

244244
SILBasicBlock *availableBB = createBasicBlock("availableBB");

lib/SILGen/SILGenConstructor.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -680,7 +680,7 @@ void SILGenFunction::emitValueConstructor(ConstructorDecl *ctor) {
680680
// Create a basic block to jump to for the implicit 'self' return.
681681
// We won't emit this until after we've emitted the body.
682682
// The epilog takes a void return because the return of 'self' is implicit.
683-
prepareEpilog(llvm::None, ctor->getEffectiveThrownInterfaceType(),
683+
prepareEpilog(llvm::None, ctor->getEffectiveThrownErrorType(),
684684
CleanupLocation(ctor));
685685

686686
// If the constructor can fail, set up an alternative epilog for constructor
@@ -1185,7 +1185,7 @@ void SILGenFunction::emitClassConstructorInitializer(ConstructorDecl *ctor) {
11851185

11861186
// Create a basic block to jump to for the implicit 'self' return.
11871187
// We won't emit the block until after we've emitted the body.
1188-
prepareEpilog(llvm::None, ctor->getEffectiveThrownInterfaceType(),
1188+
prepareEpilog(llvm::None, ctor->getEffectiveThrownErrorType(),
11891189
CleanupLocation(endOfInitLoc));
11901190

11911191
auto resultType = ctor->mapTypeIntoContext(ctor->getResultInterfaceType());
@@ -1751,7 +1751,7 @@ void SILGenFunction::emitInitAccessor(AccessorDecl *accessor) {
17511751
}
17521752

17531753
prepareEpilog(accessor->getResultInterfaceType(),
1754-
accessor->getEffectiveThrownInterfaceType(),
1754+
accessor->getEffectiveThrownErrorType(),
17551755
CleanupLocation(accessor));
17561756

17571757
emitProfilerIncrement(accessor->getTypecheckedBody());

lib/SILGen/SILGenFunction.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1051,7 +1051,7 @@ void SILGenFunction::emitFunction(FuncDecl *fd) {
10511051
emitDistributedActorFactory(fd);
10521052
} else {
10531053
prepareEpilog(fd->getResultInterfaceType(),
1054-
fd->getEffectiveThrownInterfaceType(), CleanupLocation(fd));
1054+
fd->getEffectiveThrownErrorType(), CleanupLocation(fd));
10551055

10561056
if (fd->requiresUnavailableDeclABICompatibilityStubs())
10571057
emitApplyOfUnavailableCodeReached();

lib/Sema/CSSimplify.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2957,8 +2957,8 @@ matchFunctionThrowing(ConstraintSystem &cs,
29572957
// that throws error type E2 when E1 is a subtype of E2. For the purpose
29582958
// of this comparison, a non-throwing function has thrown error type 'Never',
29592959
// and an untyped throwing function has thrown error type 'any Error'.
2960-
Type thrownError1 = getEffectiveThrownErrorTypeOrNever(func1);
2961-
Type thrownError2 = getEffectiveThrownErrorTypeOrNever(func2);
2960+
Type thrownError1 = func1->getEffectiveThrownErrorTypeOrNever();
2961+
Type thrownError2 = func2->getEffectiveThrownErrorTypeOrNever();
29622962
if (!thrownError1 || !thrownError2)
29632963
return cs.getTypeMatchSuccess();
29642964

lib/Sema/TypeCheckEffects.cpp

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -762,7 +762,7 @@ class Classification {
762762

763763
if (considerThrows) {
764764
if (auto thrownInterfaceType =
765-
func->getEffectiveThrownInterfaceType()) {
765+
func->getEffectiveThrownErrorType()) {
766766
Type thrownType =
767767
thrownInterfaceType->subst(declRef.getSubstitutions());
768768
result.merge(Classification::forThrows(thrownType,
@@ -1660,7 +1660,7 @@ class ApplyClassifier {
16601660
return Classification();
16611661

16621662
case EffectKind::Throws:
1663-
if (auto thrownError = fnType->getEffectiveThrownInterfaceType())
1663+
if (auto thrownError = fnType->getEffectiveThrownErrorType())
16641664
return Classification::forThrows(*thrownError, conditional, reason);
16651665

16661666
return Classification();
@@ -3305,13 +3305,6 @@ Type TypeChecker::errorUnion(Type type1, Type type2) {
33053305
return type1->getASTContext().getErrorExistentialType();
33063306
}
33073307

3308-
Type swift::getEffectiveThrownErrorTypeOrNever(AnyFunctionType *func) {
3309-
if (auto thrownError = func->getEffectiveThrownInterfaceType())
3310-
return *thrownError;
3311-
3312-
return func->getASTContext().getNeverType();
3313-
}
3314-
33153308
namespace {
33163309

33173310
/// Classifies a thrown error kind as Never, a specific type, or 'any Error'.

lib/Sema/TypeCheckEffects.h

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,6 @@
2222

2323
namespace swift {
2424

25-
/// Retrieve the effective thrown error type for the given function type, which
26-
/// is the thrown error type (is specified), `any Error` if untyped throwing, or
27-
/// `Never` if non-throwing.
28-
Type getEffectiveThrownErrorTypeOrNever(AnyFunctionType *func);
29-
3025
/// Classifies the result of a subtyping comparison between two thrown error
3126
/// types.
3227
enum class ThrownErrorSubtyping {

lib/Sema/TypeCheckProtocol.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -861,8 +861,8 @@ RequirementMatch swift::matchWitness(
861861
if (!reqThrownError) {
862862
// Save the thrown error types of the requirement and witness so we
863863
// can check them later.
864-
reqThrownError = getEffectiveThrownErrorTypeOrNever(reqFnType);
865-
witnessThrownError = getEffectiveThrownErrorTypeOrNever(witnessFnType);
864+
reqThrownError = reqFnType->getEffectiveThrownErrorTypeOrNever();
865+
witnessThrownError = witnessFnType->getEffectiveThrownErrorTypeOrNever();
866866
}
867867
}
868868
} else {
@@ -901,13 +901,19 @@ RequirementMatch swift::matchWitness(
901901
return RequirementMatch(witness, MatchKind::ThrowsConflict);
902902

903903
case ThrownErrorSubtyping::ExactMatch:
904-
case ThrownErrorSubtyping::Subtype:
905904
// All is well.
906905
break;
907906

907+
case ThrownErrorSubtyping::Subtype:
908+
// If there were no type parameters, we're done.
909+
if (!reqThrownError->hasTypeParameter())
910+
break;
911+
912+
LLVM_FALLTHROUGH;
913+
908914
case ThrownErrorSubtyping::Dependent:
909915
// We need to perform type matching
910-
if (auto result = matchTypes(witnessThrownError, reqThrownError)) {
916+
if (auto result = matchTypes(reqThrownError, witnessThrownError)) {
911917
return std::move(result.value());
912918
}
913919
break;

lib/Sema/TypeCheckProtocolInference.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -751,6 +751,10 @@ AssociatedTypeInference::inferTypeWitnessesViaValueWitness(ValueDecl *req,
751751
TypeBase *secondType, Type sugaredFirstType) {
752752
return true;
753753
}
754+
755+
bool considerThrownErrorTypes(Type errorType1, Type errorType2) const {
756+
return errorType1->hasTypeParameter();
757+
}
754758
};
755759

756760
// Match a requirement and witness type.

test/decl/protocol/conforms/typed_throws.swift

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,34 @@ struct ConformingToVeryThrowing: VeryThrowing {
3838
// FIXME: Diagnostic above should be better
3939
var prop4: Int { get throws(SubError) { 0 } }
4040
}
41+
42+
// Associated type inference.
43+
protocol FailureAssociatedType {
44+
associatedtype Failure: Error
45+
46+
func f() throws(Failure)
47+
}
48+
49+
struct S1: FailureAssociatedType {
50+
func f() throws(MyError) { }
51+
}
52+
53+
struct S2: FailureAssociatedType {
54+
func f() throws { }
55+
}
56+
57+
struct S3: FailureAssociatedType {
58+
func f() { }
59+
}
60+
61+
func testAssociatedTypes() {
62+
let _ = S1.Failure() // expected-error{{'S1.Failure' (aka 'MyError') cannot be constructed because it has no accessible initializers}}
63+
let _ = S2.Failure() // expected-error{{'S2.Failure' (aka 'any Error') cannot be constructed because it has no accessible initializers}}
64+
let _: Int = S3.Failure() // expected-error{{cannot convert value of type 'S3.Failure' (aka 'Never') to specified type 'Int'}}
65+
// expected-error@-1{{missing argument for parameter 'from' in call}}
66+
}
67+
68+
// Make sure we can throw the generic failure type.
69+
func assocFailureType<T: FailureAssociatedType>(_ value: T, _ error: T.Failure) throws(T.Failure) {
70+
throw error
71+
}

0 commit comments

Comments
 (0)