@@ -148,10 +148,18 @@ bool DerivedConformance::canDeriveDifferentiable(NominalTypeDecl *nominal,
148
148
if (!structDecl)
149
149
return nullptr ;
150
150
// Valid candidate must either:
151
- // - Be implicit (previously synthesized).
152
- // - Equal nominal (and conform to `AdditiveArithmetic` if flag is true).
151
+ // 1. Be implicit (previously synthesized).
153
152
if (structDecl->isImplicit ())
154
153
return structDecl;
154
+ // 2. Equal nominal's implicit parent.
155
+ // This can occur during mutually recursive constraints. Example:
156
+ // `X == X.TangentVector, X.CotangentVector.CotangentVector == X`.
157
+ if (nominal->isImplicit () && structDecl == nominal->getDeclContext () &&
158
+ TypeChecker::conformsToProtocol (structDecl->getDeclaredInterfaceType (),
159
+ diffableProto, DC,
160
+ ConformanceCheckFlags::Used))
161
+ return structDecl;
162
+ // 3. Equal nominal (and conform to `AdditiveArithmetic` if flag is true).
155
163
if (structDecl == nominal) {
156
164
if (!checkAdditiveArithmetic)
157
165
return structDecl;
@@ -815,12 +823,12 @@ static void addAssociatedTypeAliasDecl(Identifier name,
815
823
auto lookup = nominal->lookupDirect (name);
816
824
assert (lookup.size () < 2 &&
817
825
" Expected at most one associated type named member" );
818
- // If implicit typealias with the given name already exists in source
826
+ // If implicit type declaration with the given name already exists in source
819
827
// struct, return it.
820
828
if (lookup.size () == 1 ) {
821
- auto existingAlias = dyn_cast<TypeAliasDecl >(lookup.front ());
822
- assert (existingAlias && existingAlias ->isImplicit () &&
823
- " Expected lookup result to be an implicit typealias " );
829
+ auto existingTypeDecl = dyn_cast<TypeDecl >(lookup.front ());
830
+ assert (existingTypeDecl && existingTypeDecl ->isImplicit () &&
831
+ " Expected lookup result to be an implicit type declaration " );
824
832
return ;
825
833
}
826
834
// Otherwise, create a new typealias.
@@ -898,31 +906,31 @@ getOrSynthesizeAssociatedStructType(DerivedConformance &derived,
898
906
auto *nominal = derived.Nominal ;
899
907
auto &C = nominal->getASTContext ();
900
908
901
- // Get or synthesize `TangentVector`, `CotangentVector`, and
902
- // `AllDifferentiableVariables` structs at once. Synthesizing all three
903
- // structs at once is necessary in order to correctly set their mutually
904
- // recursive associated types.
909
+ // Get or synthesize `AllDifferentiableVariables`, `TangentVector`, and
910
+ // `CotangentVector` structs at once. Synthesizing all three structs at once
911
+ // is necessary in order to correctly set their mutually recursive associated
912
+ // types.
913
+ auto allDiffableVarsStructSynthesis =
914
+ getOrSynthesizeSingleAssociatedStruct (derived,
915
+ C.Id_AllDifferentiableVariables );
916
+ auto *allDiffableVarsStruct = allDiffableVarsStructSynthesis.first ;
917
+ if (!allDiffableVarsStruct)
918
+ return nullptr ;
919
+ bool freshlySynthesized = allDiffableVarsStructSynthesis.second ;
920
+
905
921
auto tangentStructSynthesis =
906
922
getOrSynthesizeSingleAssociatedStruct (derived, C.Id_TangentVector );
907
923
auto *tangentStruct = tangentStructSynthesis.first ;
908
- bool freshlySynthesized = tangentStructSynthesis.second ;
909
924
if (!tangentStruct)
910
925
return nullptr ;
926
+ freshlySynthesized |= tangentStructSynthesis.second ;
911
927
912
928
auto cotangentStructSynthesis =
913
929
getOrSynthesizeSingleAssociatedStruct (derived, C.Id_CotangentVector );
914
930
auto *cotangentStruct = cotangentStructSynthesis.first ;
915
931
if (!cotangentStruct)
916
932
return nullptr ;
917
- assert (freshlySynthesized == cotangentStructSynthesis.second );
918
-
919
- auto allDiffableVarsStructSynthesis =
920
- getOrSynthesizeSingleAssociatedStruct (derived,
921
- C.Id_AllDifferentiableVariables );
922
- auto *allDiffableVarsStruct = allDiffableVarsStructSynthesis.first ;
923
- if (!allDiffableVarsStruct)
924
- return nullptr ;
925
- assert (freshlySynthesized == allDiffableVarsStructSynthesis.second );
933
+ freshlySynthesized |= cotangentStructSynthesis.second ;
926
934
927
935
// When all structs are freshly synthesized, we check emit warnings for
928
936
// implicit `@noDerivative` members. Checking for fresh synthesis is necessary
@@ -953,9 +961,9 @@ getOrSynthesizeAssociatedStructType(DerivedConformance &derived,
953
961
addAssociatedTypeAliasDecl (C.Id_AllDifferentiableVariables ,
954
962
cotangentStruct, cotangentStruct, TC);
955
963
964
+ TC.validateDecl (allDiffableVarsStruct);
956
965
TC.validateDecl (tangentStruct);
957
966
TC.validateDecl (cotangentStruct);
958
- TC.validateDecl (allDiffableVarsStruct);
959
967
960
968
// Sanity checks for synthesized structs.
961
969
assert (DerivedConformance::canDeriveAdditiveArithmetic (tangentStruct,
@@ -1082,7 +1090,7 @@ deriveDifferentiable_AssociatedStruct(DerivedConformance &derived,
1082
1090
if (allMembersAssocTypesEqualsSelf) {
1083
1091
auto allDiffableVarsStructSynthesis = getOrSynthesizeSingleAssociatedStruct (
1084
1092
derived, C.Id_AllDifferentiableVariables );
1085
- auto allDiffableVarsStruct = allDiffableVarsStructSynthesis.first ;
1093
+ auto * allDiffableVarsStruct = allDiffableVarsStructSynthesis.first ;
1086
1094
auto freshlySynthesized = allDiffableVarsStructSynthesis.second ;
1087
1095
// `AllDifferentiableVariables` must conform to `AdditiveArithmetic`.
1088
1096
// This should be guaranteed.
0 commit comments