Skip to content

Commit fb1be04

Browse files
dan-zhengrxwei
authored andcommitted
[Sema] [AutoDiff] Modify/fix Differentiable derived conformances. (#21871)
- Fix crasher regarding bad interaction between AdditiveArithmetic and Differentiable derived conformances, where both code paths attempt to synthesize memberwise initializers. - Conform synthesized `AllDifferentiableVariables` struct to `KeyPathIterable` if parent conforms to `KeyPathIterable`. - This is important for key-path based optimizer definitions. - Modify associated struct synthesis to check whether any `@noDerivative` stored properties exist. - If any `@noDerivative` stored properties exist, `AllDifferentiableVariables` struct must be synthesized. - Refactor derived conformances code, add tests.
1 parent 2d20e63 commit fb1be04

File tree

6 files changed

+244
-172
lines changed

6 files changed

+244
-172
lines changed

include/swift/AST/Decl.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3231,6 +3231,10 @@ class NominalTypeDecl : public GenericTypeDecl, public IterableDeclContext {
32313231
/// declared with @TFParameter).
32323232
void getAllTFParameters(SmallVectorImpl<VarDecl *> &result) const;
32333233

3234+
// SWIFT_ENABLE_TENSORFLOW
3235+
/// Get the memberwise initializer of the nominal type, if it exists.
3236+
ConstructorDecl *getMemberwiseInitializer();
3237+
32343238
private:
32353239
/// Predicate used to filter StoredPropertyRange.
32363240
struct ToStoredProperty {

lib/AST/Decl.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3025,6 +3025,29 @@ NominalTypeDecl::getAllTFParameters(SmallVectorImpl<VarDecl *> &result) const {
30253025
result.push_back(member);
30263026
}
30273027

3028+
// SWIFT_ENABLE_TENSORFLOW
3029+
ConstructorDecl *NominalTypeDecl::getMemberwiseInitializer() {
3030+
ConstructorDecl *memberwiseInitDecl = nullptr;
3031+
auto ctorDecls = lookupDirect(DeclBaseName::createConstructor());
3032+
for (auto decl : ctorDecls) {
3033+
auto ctorDecl = dyn_cast<ConstructorDecl>(decl);
3034+
if (!ctorDecl)
3035+
continue;
3036+
// Continue if:
3037+
// - Constructor is not a memberwise initializer.
3038+
// - Constructor is implicit and takes no arguments, and nominal has no
3039+
// stored properties. This is ad-hoc and accepts empty struct
3040+
// constructors generated via `TypeChecker::defineDefaultConstructor`.
3041+
if (!ctorDecl->isMemberwiseInitializer() &&
3042+
!(getStoredProperties().empty() && ctorDecl->isImplicit() &&
3043+
ctorDecl->getParameters()->size() == 0))
3044+
continue;
3045+
assert(!memberwiseInitDecl && "Memberwise initializer already found");
3046+
memberwiseInitDecl = ctorDecl;
3047+
}
3048+
return memberwiseInitDecl;
3049+
}
3050+
30283051
GenericTypeDecl::GenericTypeDecl(DeclKind K, DeclContext *DC,
30293052
Identifier name, SourceLoc nameLoc,
30303053
MutableArrayRef<TypeLoc> inherited,

lib/Sema/DerivedConformanceAdditiveArithmeticVectorNumeric.cpp

Lines changed: 3 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -76,22 +76,6 @@ static ValueDecl *getProtocolRequirement(ProtocolDecl *proto, Identifier name) {
7676
return lookup[0];
7777
}
7878

79-
// Get memberwise initializer for a nominal type.
80-
static ConstructorDecl *getMemberwiseInitializer(NominalTypeDecl *nominal) {
81-
ConstructorDecl *memberwiseInitDecl = nullptr;
82-
for (auto member : nominal->getMembers()) {
83-
// Find memberwise initializer.
84-
if (!memberwiseInitDecl) {
85-
auto initDecl = dyn_cast<ConstructorDecl>(member);
86-
if (!initDecl || !initDecl->isMemberwiseInitializer())
87-
continue;
88-
assert(!memberwiseInitDecl && "Memberwise initializer already found");
89-
memberwiseInitDecl = initDecl;
90-
}
91-
}
92-
return memberwiseInitDecl;
93-
}
94-
9579
// Return the `Scalar` associated type for a ValueDecl if it conforms to
9680
// `VectorNumeric` in the given context.
9781
// If the decl does not conform to `VectorNumeric`, return a null `Type`.
@@ -182,7 +166,7 @@ static void deriveBodyMathOperator(AbstractFunctionDecl *funcDecl,
182166
auto &C = nominal->getASTContext();
183167

184168
// Create memberwise initializer: `Nominal.init(...)`.
185-
auto *memberwiseInitDecl = getMemberwiseInitializer(nominal);
169+
auto *memberwiseInitDecl = nominal->getMemberwiseInitializer();
186170
auto *initDRE =
187171
new (C) DeclRefExpr(memberwiseInitDecl, DeclNameLoc(), /*Implicit*/ true);
188172
initDRE->setFunctionRefKind(FunctionRefKind::SingleApply);
@@ -361,7 +345,7 @@ static void deriveBodyAdditiveArithmetic_zero(AbstractFunctionDecl *funcDecl) {
361345
auto *nominal = funcDecl->getDeclContext()->getSelfNominalTypeDecl();
362346
auto &C = nominal->getASTContext();
363347

364-
auto *memberwiseInitDecl = getMemberwiseInitializer(nominal);
348+
auto *memberwiseInitDecl = nominal->getMemberwiseInitializer();
365349
auto *initDRE =
366350
new (C) DeclRefExpr(memberwiseInitDecl, DeclNameLoc(), /*Implicit*/ true);
367351
initDRE->setFunctionRefKind(FunctionRefKind::SingleApply);
@@ -416,7 +400,7 @@ static ValueDecl *deriveAdditiveArithmetic_zero(DerivedConformance &derived) {
416400
// The implicit memberwise constructor must be explicitly created so that it
417401
// can called when synthesizing the `zero` property getter. Normally, the
418402
// memberwise constructor is synthesized during SILGen, which is too late.
419-
if (!getMemberwiseInitializer(nominal)) {
403+
if (!nominal->getMemberwiseInitializer()) {
420404
auto *initDecl = createImplicitConstructor(
421405
TC, nominal, ImplicitConstructorKind::Memberwise);
422406
nominal->addMember(initDecl);

0 commit comments

Comments
 (0)