Skip to content

Commit ea45763

Browse files
committed
Conform Differentiable assoc. types to ElementaryFunctions if possible.
Conform `Differentiable` synthesized associated types to `ElementaryFunctions` if possible. Similar to existing logic for `VectorProtocol`.
1 parent fefbf82 commit ea45763

File tree

4 files changed

+48
-12
lines changed

4 files changed

+48
-12
lines changed

lib/Sema/DerivedConformanceAdditiveArithmetic.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ static void deriveBodyMathOperator(AbstractFunctionDecl *funcDecl,
183183
memberOpExprs.push_back(createMemberOpExpr(member));
184184
memberNames.push_back(member->getName());
185185
}
186-
// Call memberwise initialier with member operator call expressions.
186+
// Call memberwise initializer with member operator call expressions.
187187
auto *callExpr =
188188
CallExpr::createImplicit(C, initExpr, memberOpExprs, memberNames);
189189
ASTNode returnStmt = new (C) ReturnStmt(SourceLoc(), callExpr, true);

lib/Sema/DerivedConformanceDifferentiable.cpp

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -568,12 +568,14 @@ getOrSynthesizeSingleAssociatedStruct(DerivedConformance &derived,
568568
auto diffableType = TypeLoc::withoutLoc(diffableProto->getDeclaredType());
569569
auto *addArithProto = C.getProtocol(KnownProtocolKind::AdditiveArithmetic);
570570
auto addArithType = TypeLoc::withoutLoc(addArithProto->getDeclaredType());
571+
auto *mathProto = C.getProtocol(KnownProtocolKind::ElementaryFunctions);
572+
auto mathType = TypeLoc::withoutLoc(mathProto->getDeclaredType());
571573
auto *vectorProto = C.getProtocol(KnownProtocolKind::VectorProtocol);
572574
auto vectorType = TypeLoc::withoutLoc(vectorProto->getDeclaredType());
573575
auto *kpIterableProto = C.getProtocol(KnownProtocolKind::KeyPathIterable);
574576
auto kpIterableType = TypeLoc::withoutLoc(kpIterableProto->getDeclaredType());
575577

576-
SmallVector<TypeLoc, 3> inherited {diffableType};
578+
SmallVector<TypeLoc, 3> inherited{diffableType};
577579

578580
// Cache original members and their associated types for later use.
579581
SmallVector<VarDecl *, 8> diffProperties;
@@ -589,9 +591,16 @@ getOrSynthesizeSingleAssociatedStruct(DerivedConformance &derived,
589591
bool canDeriveAdditiveArithmetic =
590592
llvm::all_of(diffProperties, [&](VarDecl *vd) {
591593
return TC.conformsToProtocol(getAssociatedType(vd, parentDC, id),
592-
addArithProto, parentDC,
593-
None);
594-
});
594+
addArithProto, parentDC, None);
595+
});
596+
597+
// Associated struct can derive `ElementaryFunctions` if the associated types
598+
// of all stored properties conform to `ElementaryFunctions`.
599+
bool canDeriveElementaryFunctions =
600+
llvm::all_of(diffProperties, [&](VarDecl *vd) {
601+
return TC.conformsToProtocol(getAssociatedType(vd, parentDC, id),
602+
mathProto, parentDC, None);
603+
});
595604

596605
// Associated struct can derive `VectorProtocol` if the associated types of
597606
// all members conform to `VectorProtocol` and share the same
@@ -625,6 +634,10 @@ getOrSynthesizeSingleAssociatedStruct(DerivedConformance &derived,
625634
None))
626635
inherited.push_back(kpIterableType);
627636
}
637+
// If all members conform to `ElementaryFunctions`, make the associated struct
638+
// conform to `ElementaryFunctions`.
639+
if (canDeriveElementaryFunctions)
640+
inherited.push_back(mathType);
628641
// If all members also conform to `VectorProtocol` with the same `Scalar`
629642
// type, make the associated struct conform to `VectorProtocol` instead of
630643
// just `AdditiveArithmetic`.

test/AutoDiff/derived_differentiable_properties.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ struct TestNoDerivative : Differentiable {
4040
// CHECK-AST: var w: Float
4141
// CHECK-AST: @noDerivative internal var technicallyDifferentiable: Float
4242
// CHECK-AST: internal init(w: Float, technicallyDifferentiable: Float)
43-
// CHECK-AST: internal struct AllDifferentiableVariables : Differentiable, AdditiveArithmetic, VectorProtocol
43+
// CHECK-AST: internal struct AllDifferentiableVariables : Differentiable, AdditiveArithmetic, ElementaryFunctions, VectorProtocol
4444
// CHECK-AST: internal typealias AllDifferentiableVariables = TestNoDerivative.AllDifferentiableVariables
4545
// CHECK-AST: internal typealias TangentVector = TestNoDerivative.AllDifferentiableVariables
4646
// CHECK-AST: internal typealias TangentVector = TestNoDerivative.AllDifferentiableVariables
@@ -54,7 +54,7 @@ struct TestKeyPathIterable : Differentiable, KeyPathIterable {
5454
// CHECK-AST: var w: Float
5555
// CHECK-AST: @noDerivative internal var technicallyDifferentiable: Float
5656
// CHECK-AST: internal init(w: Float, technicallyDifferentiable: Float)
57-
// CHECK-AST: internal struct AllDifferentiableVariables : Differentiable, AdditiveArithmetic, KeyPathIterable, VectorProtocol
57+
// CHECK-AST: internal struct AllDifferentiableVariables : Differentiable, AdditiveArithmetic, KeyPathIterable, ElementaryFunctions, VectorProtocol
5858
// CHECK-AST: internal typealias AllDifferentiableVariables = TestKeyPathIterable.AllDifferentiableVariables
5959
// CHECK-AST: internal typealias TangentVector = TestKeyPathIterable.AllDifferentiableVariables
6060
// CHECK-AST: internal typealias TangentVector = TestKeyPathIterable.AllDifferentiableVariables

test/Sema/struct_differentiable.swift

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,18 @@ func assertAllDifferentiableVariablesEqualsTangentVector<T>(_: T.Type)
88
// Verify that a type `T` conforms to `AdditiveArithmetic`.
99
func assertConformsToAdditiveArithmetic<T>(_: T.Type) where T : AdditiveArithmetic {}
1010

11+
// Verify that a type `T` conforms to `ElementaryFunctions`.
12+
func assertConformsToElementaryFunctions<T>(_: T.Type) where T : ElementaryFunctions {}
13+
1114
// Verify that a type `T` conforms to `VectorProtocol`.
1215
func assertConformsToVectorProtocol<T>(_: T.Type) where T : VectorProtocol {}
1316

1417
struct Empty : Differentiable {}
1518
func testEmpty() {
1619
assertConformsToAdditiveArithmetic(Empty.AllDifferentiableVariables.self)
20+
assertConformsToAdditiveArithmetic(Empty.TangentVector.self)
21+
assertConformsToElementaryFunctions(Empty.AllDifferentiableVariables.self)
22+
assertConformsToElementaryFunctions(Empty.TangentVector.self)
1723
}
1824

1925
// Test interaction with `AdditiveArithmetic` derived conformances.
@@ -130,19 +136,35 @@ func testAllMembersAdditiveArithmetic() {
130136
}
131137

132138
// Test type `AllMembersVectorProtocol` whose members conforms to `VectorProtocol`,
133-
// in which case we should make `TangentVector` and `TangentVector` conform to
134-
// `VectorProtocol`.
139+
// in which case we should make `AllDifferentiableVariables` and `TangentVector`
140+
// conform to `VectorProtocol`.
135141
struct MyVector : VectorProtocol, Differentiable {
136142
var w: Float
137143
var b: Float
138144
}
139145
struct AllMembersVectorProtocol : Differentiable {
140-
var w: MyVector
141-
var b: MyVector
146+
var v1: MyVector
147+
var v2: MyVector
142148
}
143149
func testAllMembersVectorProtocol() {
150+
assertConformsToVectorProtocol(AllMembersVectorProtocol.AllDifferentiableVariables.self)
144151
assertConformsToVectorProtocol(AllMembersVectorProtocol.TangentVector.self)
145-
assertConformsToVectorProtocol(AllMembersVectorProtocol.TangentVector.self)
152+
}
153+
154+
// Test type `AllMembersElementaryFunctions` whose members conforms to `ElementaryFunctions`,
155+
// in which case we should make `AllDifferentiableVariables` and `TangentVector`
156+
// conform to `ElementaryFunctions`.
157+
struct MyVector2 : ElementaryFunctions, Differentiable {
158+
var w: Float
159+
var b: Float
160+
}
161+
struct AllMembersElementaryFunctions : Differentiable {
162+
var v1: MyVector2
163+
var v2: MyVector2
164+
}
165+
func testAllMembersElementaryFunctions() {
166+
assertConformsToElementaryFunctions(AllMembersElementaryFunctions.AllDifferentiableVariables.self)
167+
assertConformsToElementaryFunctions(AllMembersElementaryFunctions.TangentVector.self)
146168
}
147169

148170
// Test type whose properties are not all differentiable.
@@ -154,6 +176,7 @@ struct DifferentiableSubset : Differentiable {
154176
}
155177
func testDifferentiableSubset() {
156178
assertConformsToAdditiveArithmetic(DifferentiableSubset.AllDifferentiableVariables.self)
179+
assertConformsToElementaryFunctions(DifferentiableSubset.AllDifferentiableVariables.self)
157180
assertConformsToVectorProtocol(DifferentiableSubset.AllDifferentiableVariables.self)
158181
assertAllDifferentiableVariablesEqualsTangentVector(DifferentiableSubset.self)
159182
_ = DifferentiableSubset.TangentVector(w: 1, b: 1)

0 commit comments

Comments
 (0)