Skip to content

Commit 38161f0

Browse files
authored
Merge pull request #69181 from DougGregor/infer-error-conformance-from-thrown-type
2 parents 3979d61 + 2d7eafd commit 38161f0

File tree

3 files changed

+57
-6
lines changed

3 files changed

+57
-6
lines changed

lib/AST/RequirementMachine/RequirementLowering.cpp

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -562,15 +562,17 @@ struct InferRequirementsWalker : public TypeWalker {
562562
// - `@differentiable(_linear)`: add
563563
// `T: Differentiable`, `T == T.TangentVector` requirements.
564564
if (auto *fnTy = ty->getAs<AnyFunctionType>()) {
565+
// Add a new conformance constraint for a fixed protocol.
566+
auto addConformanceConstraint = [&](Type type, ProtocolDecl *protocol) {
567+
Requirement req(RequirementKind::Conformance, type,
568+
protocol->getDeclaredInterfaceType());
569+
desugarRequirement(req, SourceLoc(), reqs, errors);
570+
};
571+
565572
auto &ctx = module->getASTContext();
566573
auto *differentiableProtocol =
567574
ctx.getProtocol(KnownProtocolKind::Differentiable);
568575
if (differentiableProtocol && fnTy->isDifferentiable()) {
569-
auto addConformanceConstraint = [&](Type type, ProtocolDecl *protocol) {
570-
Requirement req(RequirementKind::Conformance, type,
571-
protocol->getDeclaredInterfaceType());
572-
desugarRequirement(req, SourceLoc(), reqs, errors);
573-
};
574576
auto addSameTypeConstraint = [&](Type firstType,
575577
AssociatedTypeDecl *assocType) {
576578
auto secondType = assocType->getDeclaredInterfaceType()
@@ -596,6 +598,13 @@ struct InferRequirementsWalker : public TypeWalker {
596598
constrainParametersAndResult(fnTy->getDifferentiabilityKind() ==
597599
DifferentiabilityKind::Linear);
598600
}
601+
602+
// Infer that the thrown error type conforms to Error.
603+
if (auto thrownError = fnTy->getThrownError()) {
604+
if (auto errorProtocol = ctx.getErrorDecl()) {
605+
addConformanceConstraint(thrownError, errorProtocol);
606+
}
607+
}
599608
}
600609

601610
if (!ty->isSpecialized())

lib/Sema/TypeCheckGeneric.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -753,6 +753,28 @@ GenericSignatureRequest::evaluate(Evaluator &evaluator,
753753
inferenceSources.emplace_back(typeRepr, type);
754754
}
755755

756+
// Handle the thrown error type.
757+
auto effectiveFunc = func ? func
758+
: subscr ? subscr->getEffectfulGetAccessor()
759+
: nullptr;
760+
if (effectiveFunc) {
761+
if (auto thrownTypeRepr = effectiveFunc->getThrownTypeRepr()) {
762+
auto thrownOptions = baseOptions | TypeResolutionFlags::Direct;
763+
const auto thrownType = resolution.withOptions(thrownOptions)
764+
.resolveType(thrownTypeRepr);
765+
766+
// Add this type as an inference source.
767+
inferenceSources.emplace_back(thrownTypeRepr, thrownType);
768+
769+
// Add conformance of this type to the Error protocol.
770+
if (auto errorProtocol = ctx.getErrorDecl()) {
771+
extraReqs.push_back(
772+
Requirement(RequirementKind::Conformance, thrownType,
773+
errorProtocol->getDeclaredInterfaceType()));
774+
}
775+
}
776+
}
777+
756778
// Gather requirements from the result type.
757779
auto *resultTypeRepr = [&subscr, &func, &macro]() -> TypeRepr * {
758780
if (subscr) {

test/decl/func/typed_throws.swift

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ func testThrownMyErrorType() {
3333
func throwsGeneric<T: Error>(errorType: T.Type) throws(T) { }
3434

3535
func throwsBadGeneric<T>(errorType: T.Type) throws(T) { }
36-
// expected-error@-1{{thrown type 'T' does not conform to the 'Error' protocol}}
3736

3837
func throwsUnusedInSignature<T: Error>() throws(T) { }
3938
// expected-error@-1{{generic parameter 'T' is not used in function signature}}
@@ -103,3 +102,24 @@ func testMapArray(numbers: [Int]) {
103102
let _: Int = error // expected-error{{cannot convert value of type 'MyError' to specified type 'Int'}}
104103
}
105104
}
105+
106+
// Inference of Error conformance from the use of a generic parameter in typed
107+
// throws.
108+
func requiresError<E: Error>(_: E.Type) { }
109+
110+
func infersThrowing<E>(_ error: E.Type) throws(E) {
111+
requiresError(error)
112+
}
113+
114+
func infersThrowingNested<E>(_ body: () throws(E) -> Void) {
115+
requiresError(E.self)
116+
}
117+
118+
struct HasASubscript {
119+
subscript<E>(_: E.Type) -> Int {
120+
get throws(E) {
121+
requiresError(E.self)
122+
return 0
123+
}
124+
}
125+
}

0 commit comments

Comments
 (0)