Skip to content

Commit 86cffd7

Browse files
authored
Merge pull request #32578 from slavapestov/derive-type-witness-before-inference
Try to derive a type witness in a known conformance before attempting associated type inference
2 parents 0816167 + dfbb958 commit 86cffd7

8 files changed

+76
-81
lines changed

lib/Sema/DerivedConformanceDifferentiable.cpp

Lines changed: 14 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -729,34 +729,6 @@ getOrSynthesizeTangentVectorStruct(DerivedConformance &derived, Identifier id) {
729729
return structDecl;
730730
}
731731

732-
/// Add a typealias declaration with the given name and underlying target
733-
/// struct type to the given source nominal declaration context.
734-
static void addAssociatedTypeAliasDecl(Identifier name, DeclContext *sourceDC,
735-
StructDecl *target,
736-
ASTContext &Context) {
737-
auto *nominal = sourceDC->getSelfNominalTypeDecl();
738-
assert(nominal && "Expected `DeclContext` to be a nominal type");
739-
auto lookup = nominal->lookupDirect(name);
740-
assert(lookup.size() < 2 &&
741-
"Expected at most one associated type named member");
742-
// If implicit type declaration with the given name already exists in source
743-
// struct, return it.
744-
if (lookup.size() == 1) {
745-
auto existingTypeDecl = dyn_cast<TypeDecl>(lookup.front());
746-
assert(existingTypeDecl && existingTypeDecl->isImplicit() &&
747-
"Expected lookup result to be an implicit type declaration");
748-
return;
749-
}
750-
// Otherwise, create a new typealias.
751-
auto *aliasDecl = new (Context)
752-
TypeAliasDecl(SourceLoc(), SourceLoc(), name, SourceLoc(), {}, sourceDC);
753-
aliasDecl->setUnderlyingType(target->getDeclaredInterfaceType());
754-
aliasDecl->setImplicit();
755-
aliasDecl->setGenericSignature(sourceDC->getGenericSignatureOfContext());
756-
cast<IterableDeclContext>(sourceDC->getAsDecl())->addMember(aliasDecl);
757-
aliasDecl->copyFormalAccessFrom(nominal, /*sourceIsParentContext*/ true);
758-
};
759-
760732
/// Diagnose stored properties in the nominal that do not have an explicit
761733
/// `@noDerivative` attribute, but either:
762734
/// - Do not conform to `Differentiable`.
@@ -842,7 +814,7 @@ static void checkAndDiagnoseImplicitNoDerivative(ASTContext &Context,
842814
}
843815

844816
/// Get or synthesize `TangentVector` struct type.
845-
static Type
817+
static std::pair<Type, TypeDecl *>
846818
getOrSynthesizeTangentVectorStructType(DerivedConformance &derived) {
847819
auto *parentDC = derived.getConformanceContext();
848820
auto *nominal = derived.Nominal;
@@ -852,28 +824,28 @@ getOrSynthesizeTangentVectorStructType(DerivedConformance &derived) {
852824
auto *tangentStruct =
853825
getOrSynthesizeTangentVectorStruct(derived, C.Id_TangentVector);
854826
if (!tangentStruct)
855-
return nullptr;
827+
return std::make_pair(nullptr, nullptr);
828+
856829
// Check and emit warnings for implicit `@noDerivative` members.
857830
checkAndDiagnoseImplicitNoDerivative(C, nominal, parentDC);
858-
// Add `TangentVector` typealias for `TangentVector` struct.
859-
addAssociatedTypeAliasDecl(C.Id_TangentVector, tangentStruct, tangentStruct,
860-
C);
861831

862832
// Return the `TangentVector` struct type.
863-
return parentDC->mapTypeIntoContext(
864-
tangentStruct->getDeclaredInterfaceType());
833+
return std::make_pair(
834+
parentDC->mapTypeIntoContext(
835+
tangentStruct->getDeclaredInterfaceType()),
836+
tangentStruct);
865837
}
866838

867839
/// Synthesize the `TangentVector` struct type.
868-
static Type
840+
static std::pair<Type, TypeDecl *>
869841
deriveDifferentiable_TangentVectorStruct(DerivedConformance &derived) {
870842
auto *parentDC = derived.getConformanceContext();
871843
auto *nominal = derived.Nominal;
872844

873845
// If nominal type can derive `TangentVector` as the contextual `Self` type,
874846
// return it.
875847
if (canDeriveTangentVectorAsSelf(nominal, parentDC))
876-
return parentDC->getSelfTypeInContext();
848+
return std::make_pair(parentDC->getSelfTypeInContext(), nullptr);
877849

878850
// Otherwise, get or synthesize `TangentVector` struct type.
879851
return getOrSynthesizeTangentVectorStructType(derived);
@@ -914,16 +886,17 @@ ValueDecl *DerivedConformance::deriveDifferentiable(ValueDecl *requirement) {
914886
return nullptr;
915887
}
916888

917-
Type DerivedConformance::deriveDifferentiable(AssociatedTypeDecl *requirement) {
889+
std::pair<Type, TypeDecl *>
890+
DerivedConformance::deriveDifferentiable(AssociatedTypeDecl *requirement) {
918891
// Diagnose unknown requirements.
919892
if (requirement->getBaseName() != Context.Id_TangentVector) {
920893
Context.Diags.diagnose(requirement->getLoc(),
921894
diag::broken_differentiable_requirement);
922-
return nullptr;
895+
return std::make_pair(nullptr, nullptr);
923896
}
924897
// Diagnose conformances in disallowed contexts.
925898
if (checkAndDiagnoseDisallowedContext(requirement))
926-
return nullptr;
899+
return std::make_pair(nullptr, nullptr);
927900

928901
// Start an error diagnostic before attempting derivation.
929902
// If derivation succeeds, cancel the diagnostic.
@@ -939,5 +912,5 @@ Type DerivedConformance::deriveDifferentiable(AssociatedTypeDecl *requirement) {
939912
}
940913

941914
// Otherwise, return nullptr.
942-
return nullptr;
915+
return std::make_pair(nullptr, nullptr);
943916
}

lib/Sema/DerivedConformances.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,8 @@ class DerivedConformance {
122122
/// Derive a Differentiable type witness for a nominal type.
123123
///
124124
/// \returns the derived member, which will also be added to the type.
125-
Type deriveDifferentiable(AssociatedTypeDecl *assocType);
125+
std::pair<Type, TypeDecl *>
126+
deriveDifferentiable(AssociatedTypeDecl *assocType);
126127

127128
/// Derive a CaseIterable requirement for an enum if it has no associated
128129
/// values for any of its cases.

lib/Sema/TypeCheckProtocol.cpp

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5579,29 +5579,30 @@ ValueDecl *TypeChecker::deriveProtocolRequirement(DeclContext *DC,
55795579
llvm_unreachable("unknown derivable protocol kind");
55805580
}
55815581

5582-
Type TypeChecker::deriveTypeWitness(DeclContext *DC,
5583-
NominalTypeDecl *TypeDecl,
5584-
AssociatedTypeDecl *AssocType) {
5582+
std::pair<Type, TypeDecl *>
5583+
TypeChecker::deriveTypeWitness(DeclContext *DC,
5584+
NominalTypeDecl *TypeDecl,
5585+
AssociatedTypeDecl *AssocType) {
55855586
auto *protocol = cast<ProtocolDecl>(AssocType->getDeclContext());
55865587

55875588
auto knownKind = protocol->getKnownProtocolKind();
55885589

55895590
if (!knownKind)
5590-
return nullptr;
5591+
return std::make_pair(nullptr, nullptr);
55915592

55925593
auto Decl = DC->getInnermostDeclarationDeclContext();
55935594

55945595
DerivedConformance derived(TypeDecl->getASTContext(), Decl, TypeDecl,
55955596
protocol);
55965597
switch (*knownKind) {
55975598
case KnownProtocolKind::RawRepresentable:
5598-
return derived.deriveRawRepresentable(AssocType);
5599+
return std::make_pair(derived.deriveRawRepresentable(AssocType), nullptr);
55995600
case KnownProtocolKind::CaseIterable:
5600-
return derived.deriveCaseIterable(AssocType);
5601+
return std::make_pair(derived.deriveCaseIterable(AssocType), nullptr);
56015602
case KnownProtocolKind::Differentiable:
56025603
return derived.deriveDifferentiable(AssocType);
56035604
default:
5604-
return nullptr;
5605+
return std::make_pair(nullptr, nullptr);
56055606
}
56065607
}
56075608

lib/Sema/TypeCheckProtocol.h

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -822,15 +822,13 @@ class AssociatedTypeInference {
822822

823823
/// Compute the "derived" type witness for an associated type that is
824824
/// known to the compiler.
825-
Type computeDerivedTypeWitness(AssociatedTypeDecl *assocType);
825+
std::pair<Type, TypeDecl *>
826+
computeDerivedTypeWitness(AssociatedTypeDecl *assocType);
826827

827828
/// Compute a type witness without using a specific potential witness,
828829
/// e.g., using a fixed type (from a refined protocol), default type
829830
/// on an associated type, or deriving the type.
830-
///
831-
/// \param allowDerived Whether to allow "derived" type witnesses.
832-
Type computeAbstractTypeWitness(AssociatedTypeDecl *assocType,
833-
bool allowDerived);
831+
Type computeAbstractTypeWitness(AssociatedTypeDecl *assocType);
834832

835833
/// Substitute the current type witnesses into the given interface type.
836834
Type substCurrentTypeWitnesses(Type type);

lib/Sema/TypeCheckProtocolInference.cpp

Lines changed: 29 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -868,38 +868,37 @@ Type AssociatedTypeInference::computeDefaultTypeWitness(
868868
return defaultType;
869869
}
870870

871-
Type AssociatedTypeInference::computeDerivedTypeWitness(
871+
std::pair<Type, TypeDecl *>
872+
AssociatedTypeInference::computeDerivedTypeWitness(
872873
AssociatedTypeDecl *assocType) {
873874
if (adoptee->hasError())
874-
return Type();
875+
return std::make_pair(Type(), nullptr);
875876

876877
// Can we derive conformances for this protocol and adoptee?
877878
NominalTypeDecl *derivingTypeDecl = adoptee->getAnyNominal();
878879
if (!DerivedConformance::derivesProtocolConformance(dc, derivingTypeDecl,
879880
proto))
880-
return Type();
881+
return std::make_pair(Type(), nullptr);
881882

882883
// Try to derive the type witness.
883-
Type derivedType =
884-
TypeChecker::deriveTypeWitness(dc, derivingTypeDecl, assocType);
885-
if (!derivedType)
886-
return Type();
884+
auto result = TypeChecker::deriveTypeWitness(dc, derivingTypeDecl, assocType);
885+
if (!result.first)
886+
return std::make_pair(Type(), nullptr);
887887

888-
// Make sure that the derived type is sane.
889-
if (checkTypeWitness(derivedType, assocType, conformance)) {
888+
// Make sure that the derived type satisfies requirements.
889+
if (checkTypeWitness(result.first, assocType, conformance)) {
890890
/// FIXME: Diagnose based on this.
891891
failedDerivedAssocType = assocType;
892-
failedDerivedWitness = derivedType;
893-
return Type();
892+
failedDerivedWitness = result.first;
893+
return std::make_pair(Type(), nullptr);
894894
}
895895

896-
return derivedType;
896+
return result;
897897
}
898898

899899
Type
900900
AssociatedTypeInference::computeAbstractTypeWitness(
901-
AssociatedTypeDecl *assocType,
902-
bool allowDerived) {
901+
AssociatedTypeDecl *assocType) {
903902
// We don't have a type witness for this associated type, so go
904903
// looking for more options.
905904
if (Type concreteType = computeFixedTypeWitness(assocType))
@@ -909,12 +908,6 @@ AssociatedTypeInference::computeAbstractTypeWitness(
909908
if (Type defaultType = computeDefaultTypeWitness(assocType))
910909
return defaultType;
911910

912-
// If we can derive a type witness, do so.
913-
if (allowDerived) {
914-
if (Type derivedType = computeDerivedTypeWitness(assocType))
915-
return derivedType;
916-
}
917-
918911
// If there is a generic parameter of the named type, use that.
919912
if (auto genericSig = dc->getGenericSignatureOfContext()) {
920913
for (auto gp : genericSig->getInnermostGenericParams()) {
@@ -1197,8 +1190,7 @@ void AssociatedTypeInference::findSolutionsRec(
11971190

11981191
// Try to compute the type without the aid of a specific potential
11991192
// witness.
1200-
if (Type type = computeAbstractTypeWitness(assocType,
1201-
/*allowDerived=*/true)) {
1193+
if (Type type = computeAbstractTypeWitness(assocType)) {
12021194
if (type->hasError()) {
12031195
recordMissing();
12041196
return;
@@ -1880,10 +1872,23 @@ auto AssociatedTypeInference::solve(ConformanceChecker &checker)
18801872
continue;
18811873

18821874
case ResolveWitnessResult::Missing:
1883-
// Note that we haven't resolved this associated type yet.
1884-
unresolvedAssocTypes.insert(assocType);
1875+
// We did not find the witness via name lookup. Try to derive
1876+
// it below.
18851877
break;
18861878
}
1879+
1880+
// Finally, try to derive the witness if we know how.
1881+
auto derivedType = computeDerivedTypeWitness(assocType);
1882+
if (derivedType.first) {
1883+
checker.recordTypeWitness(assocType,
1884+
derivedType.first->mapTypeOutOfContext(),
1885+
derivedType.second);
1886+
continue;
1887+
}
1888+
1889+
// We failed to derive the witness. We're going to go on to try
1890+
// to infer it from potential value witnesses next.
1891+
unresolvedAssocTypes.insert(assocType);
18871892
}
18881893

18891894
// Result variable to use for returns so that we get NRVO.

lib/Sema/TypeChecker.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -911,8 +911,9 @@ ValueDecl *deriveProtocolRequirement(DeclContext *DC,
911911
/// Derive an implicit type witness for the given associated type in
912912
/// the conformance of the given nominal type to some known
913913
/// protocol.
914-
Type deriveTypeWitness(DeclContext *DC, NominalTypeDecl *nominal,
915-
AssociatedTypeDecl *assocType);
914+
std::pair<Type, TypeDecl *>
915+
deriveTypeWitness(DeclContext *DC, NominalTypeDecl *nominal,
916+
AssociatedTypeDecl *assocType);
916917

917918
/// \name Name lookup
918919
///

test/Sema/enum_raw_representable.swift

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ var doubles: [Double] = serialize([Bar.a, .b, .c])
4646
var foos: [Foo] = deserialize([1, 2, 3])
4747
var bars: [Bar] = deserialize([1.2, 3.4, 5.6])
4848

49-
// Infer RawValue from witnesses.
49+
// We reject enums where the raw type stated in the inheritance clause does not
50+
// match the types of the witnesses.
5051
enum Color : Int {
5152
case red
5253
case blue
@@ -56,11 +57,13 @@ enum Color : Int {
5657
}
5758

5859
var rawValue: Double {
60+
// expected-error@-1 {{invalid redeclaration of synthesized implementation for protocol requirement 'rawValue'}}
5961
return 1.0
6062
}
6163
}
6264

6365
var colorRaw: Color.RawValue = 7.5
66+
// expected-error@-1 {{cannot convert value of type 'Double' to specified type 'Color.RawValue' (aka 'Int')}}
6467

6568
// Mismatched case types
6669

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
// RUN: %target-typecheck-verify-swift
2+
3+
// This used to fail with "reference to invalid associated type 'RawValue' of type 'E'"
4+
_ = E(rawValue: 123)
5+
6+
enum E : Int {
7+
case a = 123
8+
9+
init?(rawValue: RawValue) {
10+
self = .a
11+
}
12+
}
13+

0 commit comments

Comments
 (0)