Skip to content

Commit 8f3d160

Browse files
dan-zhengrxwei
authored andcommitted
[AutoDiff] [API] Continue deriving conformances for Differentiable. (#21646)
* [API] Continue deriving conformances for `Differentiable`. Continue implementing `Differentiable` derived conformances for struct types. A follow-up to #21580. This patch introduces support for: - Synthesizing member `TangentVector` and `CotangentVector` structs. This is necessary when not all members have `Self == TangentVector == CotangentVector`. - Marking synthesized vector space structs and typealiases with `@_fieldwiseProductSpace`. - Making synthesis consistent with recent changes to `Differentiable`. - Vector space associated types now conform to `AdditiveArithmetic`, not `VectorNumeric` (which is limited to a single `Scalar` type). - Use `@noDerivative` attribute to determine vector space struct members. A member in the original struct type marked with `@noDerivative` will not have corresponding members in vector space structs. - Enabliing derived conformances to `AdditiveArithmetic` and `VectorNumeric` for structs with no stored properties. - Many bug fixes regarding code synthesis, details omitted. There are some todos: - Fix SR-9595, which is related to mutually recursive associated types. The bug blocks `Differentiable` synthesis for structs whose members have generic types (e.g. `<T : Differentiable>`). Failing tests due to this bug are commented in `test/Sema/struct_differentiable.swift`. - Synthesize `@differentiable(wrt: (self))` on stored properties of `Differentiable`-conforming types. - It's possible to do this synthesis during the derived conformances code path, but the derived conformances code path is not triggered if a type manually defines all `Differentiable` requirements. - Add `Differentiable` derived conformances runtime tests. - Future: enable `Differentiable` derived conformances for enums. * [API] Conform synthesized vector space structs to `KeyPathIterable`. * Minor cleanup. * [API] Un-conform synthesized vector space structs from `KeyPathIterable`. Synthesizing `KeyPathIterable` for these structs is ad-hoc and limited in usefulness (not useful when `Self != TangentVector` or `Self != CotangentVector`). A more general solution involves providing a mapping between `Self`s keypaths and vector space structs' keypaths. * Fix test/AutoDiff/noderivative-attr.swift.
1 parent cfe0bc7 commit 8f3d160

7 files changed

+385
-194
lines changed

lib/Sema/DerivedConformanceAdditiveArithmeticVectorNumeric.cpp

Lines changed: 42 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -93,33 +93,35 @@ static ConstructorDecl *getMemberwiseInitializer(NominalTypeDecl *nominal) {
9393
}
9494

9595
// Return the `Scalar` associated type for a ValueDecl if it conforms to
96-
// `VectorNumeric`.
96+
// `VectorNumeric` in the given context.
9797
// If the decl does not conform to `VectorNumeric`, return a null `Type`.
98-
static Type getVectorNumericScalarAssocType(ValueDecl *decl) {
98+
static Type getVectorNumericScalarAssocType(VarDecl *decl, DeclContext *DC) {
9999
auto &C = decl->getASTContext();
100100
auto *vectorNumericProto = C.getProtocol(KnownProtocolKind::VectorNumeric);
101-
auto declType =
102-
decl->getDeclContext()->mapTypeIntoContext(decl->getInterfaceType());
103-
auto conf = TypeChecker::conformsToProtocol(declType, vectorNumericProto,
104-
decl->getDeclContext(),
101+
if (!decl->hasType())
102+
C.getLazyResolver()->resolveDeclSignature(decl);
103+
if (!decl->hasType())
104+
return Type();
105+
auto declType = decl->getType()->hasArchetype()
106+
? decl->getType()
107+
: DC->mapTypeIntoContext(decl->getType());
108+
auto conf = TypeChecker::conformsToProtocol(declType, vectorNumericProto, DC,
105109
ConformanceCheckFlags::Used);
106110
if (!conf)
107111
return Type();
108112
Type scalarType = ProtocolConformanceRef::getTypeWitnessByName(
109-
decl->getInterfaceType(), *conf, C.Id_Scalar, C.getLazyResolver());
113+
declType, *conf, C.Id_Scalar, C.getLazyResolver());
110114
assert(scalarType && "'Scalar' associated type not found");
111115
return scalarType;
112116
}
113117

114-
static Type deriveVectorNumeric_Scalar(NominalTypeDecl *nominal) {
115-
// Must be a struct type.
118+
static Type deriveVectorNumeric_Scalar(NominalTypeDecl *nominal,
119+
DeclContext *DC) {
120+
// Nominal type must be a struct. (Zero stored properties is okay.)
116121
auto *structDecl = dyn_cast<StructDecl>(nominal);
117122
auto &C = nominal->getASTContext();
118123
if (!structDecl)
119124
return Type();
120-
// Struct must have at least one stored property.
121-
if (structDecl->getStoredProperties().empty())
122-
return Type();
123125
// If all stored properties conform to `VectorNumeric` and have the same
124126
// `Scalar` associated type, return that `Scalar` associated type.
125127
// Otherwise, the `Scalar` type cannot be derived.
@@ -129,7 +131,7 @@ static Type deriveVectorNumeric_Scalar(NominalTypeDecl *nominal) {
129131
C.getLazyResolver()->resolveDeclSignature(member);
130132
if (!member->hasInterfaceType())
131133
return Type();
132-
auto scalarType = getVectorNumericScalarAssocType(member);
134+
auto scalarType = getVectorNumericScalarAssocType(member, DC);
133135
// If stored property does not conform to `VectorNumeric`, return null
134136
// `Type`.
135137
if (!scalarType)
@@ -146,31 +148,31 @@ static Type deriveVectorNumeric_Scalar(NominalTypeDecl *nominal) {
146148
return sameScalarType;
147149
}
148150

149-
bool DerivedConformance::canDeriveAdditiveArithmetic(NominalTypeDecl *nominal) {
150-
// Must be a struct type.
151+
bool DerivedConformance::canDeriveAdditiveArithmetic(NominalTypeDecl *nominal,
152+
DeclContext *DC) {
153+
// Nominal type must be a struct. (Zero stored properties is okay.)
151154
auto *structDecl = dyn_cast<StructDecl>(nominal);
152155
if (!structDecl)
153156
return false;
154-
// Struct must have at least one stored property.
155-
if (structDecl->getStoredProperties().empty())
156-
return false;
157157
// All stored properties must conform to `AdditiveArithmetic`.
158158
auto &C = nominal->getASTContext();
159159
auto *addArithProto = C.getProtocol(KnownProtocolKind::AdditiveArithmetic);
160160
return llvm::all_of(structDecl->getStoredProperties(), [&](VarDecl *v) {
161-
if (!v->hasInterfaceType())
161+
if (!v->getType())
162162
C.getLazyResolver()->resolveDeclSignature(v);
163-
if (!v->hasInterfaceType())
163+
if (!v->getType())
164164
return false;
165-
auto conf = TypeChecker::conformsToProtocol(v->getType(), addArithProto,
166-
v->getDeclContext(),
167-
ConformanceCheckFlags::Used);
168-
return (bool)conf;
165+
auto declType = v->getType()->hasArchetype()
166+
? v->getType()
167+
: DC->mapTypeIntoContext(v->getType());
168+
return (bool)TypeChecker::conformsToProtocol(declType, addArithProto, DC,
169+
ConformanceCheckFlags::Used);
169170
});
170171
}
171172

172-
bool DerivedConformance::canDeriveVectorNumeric(NominalTypeDecl *nominal) {
173-
return bool(deriveVectorNumeric_Scalar(nominal));
173+
bool DerivedConformance::canDeriveVectorNumeric(NominalTypeDecl *nominal,
174+
DeclContext *DC) {
175+
return bool(deriveVectorNumeric_Scalar(nominal, DC));
174176
}
175177

176178
// Synthesize body for the given math operator.
@@ -203,7 +205,11 @@ static void deriveBodyMathOperator(AbstractFunctionDecl *funcDecl,
203205
// Create expression combining lhs and rhs members using member operator.
204206
auto createMemberOpExpr = [&](VarDecl *member) -> Expr * {
205207
auto module = nominal->getModuleContext();
206-
auto confRef = module->lookupConformance(member->getType(), proto);
208+
auto memberType = member->getType()->hasArchetype()
209+
? member->getType()
210+
: nominal->mapTypeIntoContext(
211+
member->getType()->mapTypeOutOfContext());
212+
auto confRef = module->lookupConformance(memberType, proto);
207213
assert(confRef && "Member does not conform to math protocol");
208214

209215
// Get member type's math operator, e.g. `Member.+`.
@@ -213,8 +219,10 @@ static void deriveBodyMathOperator(AbstractFunctionDecl *funcDecl,
213219
// If conformance reference is concrete, then use concrete witness
214220
// declaration for the operator.
215221
if (confRef->isConcrete())
216-
memberOpDecl =
217-
confRef->getConcrete()->getWitnessDecl(operatorReq, nullptr);
222+
if (auto opDecl =
223+
confRef->getConcrete()->getWitnessDecl(operatorReq, nullptr))
224+
memberOpDecl = opDecl;
225+
assert(memberOpDecl && "Member operator declaration must exist");
218226
auto memberOpDRE =
219227
new (C) DeclRefExpr(memberOpDecl, DeclNameLoc(), /*Implicit*/ true);
220228
auto *memberTypeExpr = TypeExpr::createImplicit(member->getType(), C);
@@ -296,8 +304,9 @@ static ValueDecl *deriveMathOperator(DerivedConformance &derived,
296304
case Subtract:
297305
return std::make_pair(selfInterfaceType, selfInterfaceType);
298306
case ScalarMultiply:
299-
return std::make_pair(deriveVectorNumeric_Scalar(nominal),
300-
selfInterfaceType);
307+
return std::make_pair(
308+
deriveVectorNumeric_Scalar(nominal, parentDC)->mapTypeOutOfContext(),
309+
selfInterfaceType);
301310
}
302311
};
303312

@@ -376,7 +385,7 @@ static void deriveBodyAdditiveArithmetic_zero(AbstractFunctionDecl *funcDecl) {
376385
return new (C) MemberRefExpr(memberTypeExpr, SourceLoc(), zeroReq,
377386
DeclNameLoc(), /*Implicit*/ true);
378387
}
379-
// Otherwise,return reference to concrete witness declaration for `zero`.
388+
// Otherwise, return reference to concrete witness declaration for `zero`.
380389
auto conf = confRef->getConcrete();
381390
auto zeroDecl = conf->getWitnessDecl(zeroReq, nullptr);
382391
return new (C) MemberRefExpr(memberTypeExpr, SourceLoc(), zeroDecl,
@@ -462,8 +471,7 @@ ValueDecl *DerivedConformance::deriveVectorNumeric(ValueDecl *requirement) {
462471

463472
Type DerivedConformance::deriveVectorNumeric(AssociatedTypeDecl *requirement) {
464473
if (requirement->getBaseName() == TC.Context.Id_Scalar) {
465-
auto rawType = deriveVectorNumeric_Scalar(Nominal);
466-
return getConformanceContext()->mapTypeIntoContext(rawType);
474+
return deriveVectorNumeric_Scalar(Nominal, getConformanceContext());
467475
}
468476
TC.diagnose(requirement->getLoc(), diag::broken_vector_numeric_requirement);
469477
return nullptr;

0 commit comments

Comments
 (0)