Skip to content

Commit f34c92f

Browse files
authored
[AutoDiff] Derive 'EuclideanDifferentiable' vector view from members's vector views (#26890)
[TF-785](https://bugs.swift.org/browse/TF-785)
1 parent 7dc9191 commit f34c92f

11 files changed

+95
-53
lines changed

include/swift/AST/KnownIdentifiers.def

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ IDENTIFIER(x)
157157
// Differentiable
158158
IDENTIFIER(TangentVector)
159159
IDENTIFIER(move)
160-
IDENTIFIER(vectorView)
160+
IDENTIFIER(differentiableVectorView)
161161

162162
// Kinds of layout constraints
163163
IDENTIFIER_WITH_NAME(UnknownLayout, "_UnknownLayout")

lib/Sema/DerivedConformanceDifferentiable.cpp

Lines changed: 36 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,8 @@ bool DerivedConformance::canDeriveEuclideanDifferentiable(
205205
return false;
206206
auto &C = nominal->getASTContext();
207207
auto *lazyResolver = C.getLazyResolver();
208-
auto *addArithProto = C.getProtocol(KnownProtocolKind::AdditiveArithmetic);
208+
auto *eucDiffProto =
209+
C.getProtocol(KnownProtocolKind::EuclideanDifferentiable);
209210
// Return true if all differentiation stored properties conform to
210211
// `AdditiveArithmetic` and their `TangentVector` equals themselves.
211212
SmallVector<VarDecl *, 16> diffProperties;
@@ -216,10 +217,8 @@ bool DerivedConformance::canDeriveEuclideanDifferentiable(
216217
if (!member->hasInterfaceType())
217218
return false;
218219
auto varType = DC->mapTypeIntoContext(member->getValueInterfaceType());
219-
if (!TypeChecker::conformsToProtocol(varType, addArithProto, DC, None))
220-
return false;
221-
auto memberAssocType = getTangentVectorType(member, DC);
222-
return member->getType()->isEqual(memberAssocType);
220+
return (bool)TypeChecker::conformsToProtocol(
221+
varType, eucDiffProto, DC, None);
223222
});
224223
}
225224

@@ -370,8 +369,8 @@ static ValueDecl *deriveDifferentiable_move(DerivedConformance &derived) {
370369
{deriveBodyDifferentiable_move, nullptr});
371370
}
372371

373-
/// Synthesize the `vectorView` property declaration.
374-
static ValueDecl *deriveEuclideanDifferentiable_vectorView(
372+
/// Synthesize the `differentiableVectorView` property declaration.
373+
static ValueDecl *deriveEuclideanDifferentiable_differentiableVectorView(
375374
DerivedConformance &derived) {
376375
auto &C = derived.TC.Context;
377376
auto *parentDC = derived.getConformanceContext();
@@ -383,8 +382,8 @@ static ValueDecl *deriveEuclideanDifferentiable_vectorView(
383382
VarDecl *vectorViewDecl;
384383
PatternBindingDecl *pbDecl;
385384
std::tie(vectorViewDecl, pbDecl) = derived.declareDerivedProperty(
386-
C.Id_vectorView, tangentType, tangentContextualType, /*isStatic*/ false,
387-
/*isFinal*/ true);
385+
C.Id_differentiableVectorView, tangentType, tangentContextualType,
386+
/*isStatic*/ false, /*isFinal*/ true);
388387

389388
struct GetterSynthesizerContext {
390389
StructDecl *tangentDecl;
@@ -397,7 +396,13 @@ static ValueDecl *deriveEuclideanDifferentiable_vectorView(
397396
assert(context && "Invalid context");
398397
auto *parentDC = getterDecl->getParent();
399398
auto *nominal = parentDC->getSelfNominalTypeDecl();
399+
auto *module = nominal->getModuleContext();
400400
auto &C = nominal->getASTContext();
401+
auto *eucDiffProto =
402+
C.getProtocol(KnownProtocolKind::EuclideanDifferentiable);
403+
auto *vectorViewReq =
404+
eucDiffProto->lookupDirect(C.Id_differentiableVectorView).front();
405+
401406
SmallVector<VarDecl *, 8> diffProperties;
402407
getStoredPropertiesForDifferentiation(nominal, nominal->getDeclContext(),
403408
diffProperties);
@@ -419,20 +424,32 @@ static ValueDecl *deriveEuclideanDifferentiable_vectorView(
419424

420425
// Create a call:
421426
// TangentVector.init(
422-
// <property_name_1...>: self.<property_name_1>,
423-
// <property_name_2...>: self.<property_name_2>,
427+
// <property_name_1...>:
428+
// self.differentiableVectorView.<property_name_1>,
429+
// <property_name_2...>:
430+
// self.differentiableVectorView.<property_name_2>,
424431
// ...
425432
// )
426433
SmallVector<Identifier, 8> argLabels;
427434
SmallVector<Expr *, 8> memberRefs;
428-
auto *selfDRE = new (C) DeclRefExpr(getterDecl->getImplicitSelfDecl(),
429-
DeclNameLoc(),
430-
/*Implicit*/ true);
431435
for (auto *member : diffProperties) {
436+
auto *selfDRE = new (C) DeclRefExpr(getterDecl->getImplicitSelfDecl(),
437+
DeclNameLoc(),
438+
/*Implicit*/ true);
439+
auto *memberExpr = new (C) MemberRefExpr(
440+
selfDRE, SourceLoc(), member, DeclNameLoc(), /*Implicit*/ true);
441+
auto memberType =
442+
parentDC->mapTypeIntoContext(member->getValueInterfaceType());
443+
auto confRef = module->lookupConformance(memberType, eucDiffProto);
444+
assert(confRef &&
445+
"Member missing conformance to `EuclideanDifferentiable`");
446+
ConcreteDeclRef memberDeclRef = vectorViewReq;
447+
if (confRef->isConcrete())
448+
memberDeclRef = confRef->getConcrete()->getWitnessDecl(vectorViewReq);
432449
argLabels.push_back(member->getName());
433-
memberRefs.push_back(
434-
new (C) MemberRefExpr(selfDRE, SourceLoc(), member, DeclNameLoc(),
435-
/*Implicit*/ true));
450+
memberRefs.push_back(new (C) MemberRefExpr(
451+
memberExpr, SourceLoc(), memberDeclRef, DeclNameLoc(),
452+
/*Implicit*/ true));
436453
}
437454
assert(memberRefs.size() == argLabels.size());
438455
CallExpr *callExpr =
@@ -875,8 +892,8 @@ ValueDecl *DerivedConformance::deriveEuclideanDifferentiable(
875892
// Diagnose conformances in disallowed contexts.
876893
if (checkAndDiagnoseDisallowedContext(requirement))
877894
return nullptr;
878-
if (requirement->getFullName() == TC.Context.Id_vectorView)
879-
return deriveEuclideanDifferentiable_vectorView(*this);
895+
if (requirement->getFullName() == TC.Context.Id_differentiableVectorView)
896+
return deriveEuclideanDifferentiable_differentiableVectorView(*this);
880897
TC.diagnose(requirement->getLoc(),
881898
diag::broken_euclidean_differentiable_requirement);
882899
return nullptr;

lib/Sema/DerivedConformances.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -260,8 +260,8 @@ ValueDecl *DerivedConformance::getDerivableRequirement(NominalTypeDecl *nominal,
260260
return getRequirement(KnownProtocolKind::AdditiveArithmetic);
261261

262262
// SWIFT_ENABLE_TENSORFLOW
263-
// EuclideanDifferentiable.vectorView
264-
if (name.isSimpleName(ctx.Id_vectorView))
263+
// EuclideanDifferentiable.differentiableVectorView
264+
if (name.isSimpleName(ctx.Id_differentiableVectorView))
265265
return getRequirement(KnownProtocolKind::EuclideanDifferentiable);
266266

267267
// SWIFT_ENABLE_TENSORFLOW

stdlib/public/core/Array.swift

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1965,6 +1965,14 @@ extension Array where Element : Differentiable {
19651965
}
19661966
}
19671967

1968+
extension Array.DifferentiableView : EuclideanDifferentiable
1969+
where Element : EuclideanDifferentiable {
1970+
public var differentiableVectorView: Array.DifferentiableView.TangentVector {
1971+
Array.DifferentiableView.TangentVector(
1972+
base.map { $0.differentiableVectorView })
1973+
}
1974+
}
1975+
19681976
extension Array.DifferentiableView : Equatable where Element : Equatable {
19691977
public static func == (
19701978
lhs: Array.DifferentiableView,
@@ -2061,6 +2069,13 @@ extension Array : Differentiable where Element : Differentiable {
20612069
}
20622070
}
20632071

2072+
extension Array : EuclideanDifferentiable
2073+
where Element : EuclideanDifferentiable {
2074+
public var differentiableVectorView: TangentVector {
2075+
TangentVector(map { $0.differentiableVectorView })
2076+
}
2077+
}
2078+
20642079
extension Array where Element : Differentiable {
20652080
public func _vjpSubscript(index: Int) ->
20662081
(Element, (Element.TangentVector) -> TangentVector)

stdlib/public/core/AutoDiff.swift

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -176,14 +176,16 @@ public protocol Differentiable {
176176
""")
177177
var zeroTangentVector: TangentVector { get }
178178

179-
@available(*, deprecated,
180-
message: "'AllDifferentiableVariables' is now equal to 'Self' and will be removed")
179+
@available(*, deprecated, message: """
180+
'AllDifferentiableVariables' is now equal to 'Self' and will be removed
181+
""")
181182
typealias AllDifferentiableVariables = Self
182183
}
183184

184185
public extension Differentiable {
185-
@available(*, deprecated,
186-
message: "'allDifferentiableVariables' is now equal to 'self' and will be removed")
186+
@available(*, deprecated, message: """
187+
'allDifferentiableVariables' is now equal to 'self' and will be removed
188+
""")
187189
var allDifferentiableVariables: AllDifferentiableVariables {
188190
get { return self }
189191
set { self = newValue }
@@ -204,8 +206,9 @@ public extension Differentiable where TangentVector == Self {
204206
}
205207
}
206208

207-
/// A type that consists of a differentiable vector space and some other
208-
/// non-differentiable component.
209+
/// A type that is differentiable in the Euclidean space.
210+
/// The type may represent a vector space, or consist of a vector space and some
211+
/// other non-differentiable component.
209212
///
210213
/// Mathematically, this represents a product manifold that consists of
211214
/// a differentiable vector space and some arbitrary manifold, where the tangent
@@ -229,11 +232,11 @@ public extension Differentiable where TangentVector == Self {
229232
/// `TangentVector` is equal to its vector space component.
230233
public protocol EuclideanDifferentiable: Differentiable {
231234
/// The differentiable vector component of `self`.
232-
var vectorView: TangentVector { get }
235+
var differentiableVectorView: TangentVector { get }
233236
}
234237

235238
public extension EuclideanDifferentiable where TangentVector == Self {
236-
var vectorView: TangentVector { _read { yield self } }
239+
var differentiableVectorView: TangentVector { _read { yield self } }
237240
}
238241

239242
/// Returns `x` like an identity function. When used in a context where `x` is
@@ -776,6 +779,9 @@ internal protocol _AnyDerivativeBox {
776779
// `Differentiable` requirements.
777780
mutating func _move(along direction: _AnyDerivativeBox)
778781

782+
// `EuclideanDifferentiable` requirements.
783+
var _differentiableVectorView: _AnyDerivativeBox { get }
784+
779785
/// The underlying base value, type-erased to `Any`.
780786
var _typeErasedBase: Any { get }
781787

@@ -883,14 +889,19 @@ internal struct _ConcreteDerivativeBox<T> : _AnyDerivativeBox
883889
}
884890
_base.move(along: directionBase)
885891
}
892+
893+
// `EuclideanDifferentiable` requirements.
894+
var _differentiableVectorView: _AnyDerivativeBox {
895+
return self
896+
}
886897
}
887898

888899
/// A type-erased derivative value.
889900
///
890901
/// The `AnyDerivative` type forwards its operations to an arbitrary underlying
891902
/// base derivative value conforming to `Differentiable` and
892903
/// `AdditiveArithmetic`, hiding the specifics of the underlying value.
893-
public struct AnyDerivative : Differentiable & AdditiveArithmetic {
904+
public struct AnyDerivative : EuclideanDifferentiable & AdditiveArithmetic {
894905
internal var _box: _AnyDerivativeBox
895906

896907
internal init(_box: _AnyDerivativeBox) {
@@ -931,7 +942,7 @@ public struct AnyDerivative : Differentiable & AdditiveArithmetic {
931942
/// Internal struct representing an opaque zero value.
932943
@frozen
933944
@usableFromInline
934-
internal struct OpaqueZero : Differentiable & AdditiveArithmetic {}
945+
internal struct OpaqueZero : EuclideanDifferentiable & AdditiveArithmetic {}
935946

936947
public static var zero: AnyDerivative {
937948
return AnyDerivative(
@@ -974,6 +985,11 @@ public struct AnyDerivative : Differentiable & AdditiveArithmetic {
974985
}
975986
_box._move(along: direction._box)
976987
}
988+
989+
// `EuclideanDifferentiable` requirements.
990+
public var differentiableVectorView: TangentVector {
991+
return self
992+
}
977993
}
978994

979995
//===----------------------------------------------------------------------===//

stdlib/public/core/FloatingPointTypes.swift.gyb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1302,7 +1302,7 @@ extension ${Self} : VectorProtocol {
13021302
}
13031303
}
13041304

1305-
extension ${Self} : Differentiable {
1305+
extension ${Self} : EuclideanDifferentiable {
13061306
public typealias TangentVector = ${Self}
13071307

13081308
public mutating func move(along direction: TangentVector) {

stdlib/public/core/SIMDVectorTypes.swift.gyb

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ public struct SIMD${n}<Scalar>: SIMD where Scalar: SIMDScalar {
4747
/// Accesses the scalar at the specified position.
4848
// SWIFT_ENABLE_TENSORFLOW
4949
@differentiable(vjp: _vjpSubscript
50-
where Scalar : Differentiable & BinaryFloatingPoint,
50+
where Scalar : EuclideanDifferentiable & BinaryFloatingPoint,
5151
Scalar.TangentVector : BinaryFloatingPoint)
5252
public subscript(index: Int) -> Scalar {
5353
@_transparent get {
@@ -192,14 +192,14 @@ extension SIMD${n} where Scalar: BinaryFloatingPoint {
192192
// SWIFT_ENABLE_TENSORFLOW
193193
extension SIMD${n} : AdditiveArithmetic where Scalar : FloatingPoint {}
194194

195-
extension SIMD${n} : Differentiable
196-
where Scalar : Differentiable & BinaryFloatingPoint,
195+
extension SIMD${n} : Differentiable & EuclideanDifferentiable
196+
where Scalar : EuclideanDifferentiable & BinaryFloatingPoint,
197197
Scalar.TangentVector : BinaryFloatingPoint {
198198
public typealias TangentVector = SIMD${n}
199199
}
200200

201201
extension SIMD${n}
202-
where Scalar : Differentiable & BinaryFloatingPoint,
202+
where Scalar : EuclideanDifferentiable & BinaryFloatingPoint,
203203
Scalar.TangentVector : BinaryFloatingPoint {
204204
@usableFromInline
205205
internal func _vjpSubscript(index: Int)

test/AutoDiff/derived_differentiable.swift

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ public struct Foo : EuclideanDifferentiable {
1414
// CHECK-AST: internal init(a: Float)
1515
// CHECK-AST: public struct TangentVector
1616
// CHECK-AST: public typealias TangentVector = Foo.TangentVector
17-
// CHECK-AST: public var vectorView: Foo.TangentVector { get }
17+
// CHECK-AST: public var differentiableVectorView: Foo.TangentVector { get }
1818

1919
// CHECK-SILGEN-LABEL: // Foo.a.getter
2020
// CHECK-SILGEN-NEXT: sil [transparent] [serialized] [differentiable source 0 wrt 0] [ossa] @$s22derived_differentiable3FooV1aSfvg : $@convention(method) (Foo) -> Float
@@ -32,8 +32,8 @@ let _: @differentiable (AdditiveTangentIsSelf) -> Float = { x in
3232
// CHECK-AST: internal var dummy: PointwiseMultiplicativeDummy
3333
// CHECK-AST: internal init(a: Float, dummy: PointwiseMultiplicativeDummy)
3434
// CHECK-AST: internal typealias TangentVector = AdditiveTangentIsSelf
35-
// The following should not exist because when `Self == Self.TangentVector`, `vectorView` is not synthesized.
36-
// CHECK-AST-NOT: internal var vectorView: AdditiveTangentIsSelf { get }
35+
// The following should not exist because when `Self == Self.TangentVector`, `differentiableVectorView` is not synthesized.
36+
// CHECK-AST-NOT: internal var differentiableVectorView: AdditiveTangentIsSelf { get }
3737

3838
struct TestNoDerivative : EuclideanDifferentiable {
3939
var w: Float
@@ -46,7 +46,7 @@ struct TestNoDerivative : EuclideanDifferentiable {
4646
// CHECK-AST: internal init(w: Float, technicallyDifferentiable: Float)
4747
// CHECK-AST: internal struct TangentVector : Differentiable, AdditiveArithmetic, ElementaryFunctions, VectorProtocol
4848
// CHECK-AST: internal typealias TangentVector = TestNoDerivative.TangentVector
49-
// CHECK-AST: internal var vectorView: TestNoDerivative.TangentVector { get }
49+
// CHECK-AST: internal var differentiableVectorView: TestNoDerivative.TangentVector { get }
5050

5151
struct TestPointwiseMultiplicative : Differentiable {
5252
var w: PointwiseMultiplicativeDummy

test/AutoDiff/derived_differentiable_runtime.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ DerivedConformanceTests.test("EuclideanVectorView") {
3838
init() { x = [1, 2, 3, 4]; y = .zero }
3939
}
4040
let x = Foo()
41-
expectEqual(Foo.TangentVector(x: [1, 2, 3, 4]), x.vectorView)
41+
expectEqual(Foo.TangentVector(x: [1, 2, 3, 4]), x.differentiableVectorView)
4242
}
4343
do {
4444
class FooClass: EuclideanDifferentiable {
@@ -47,7 +47,7 @@ DerivedConformanceTests.test("EuclideanVectorView") {
4747
init() { x = [1, 2, 3, 4]; y = .zero }
4848
}
4949
let x = FooClass()
50-
expectEqual(FooClass.TangentVector(x: [1, 2, 3, 4]), x.vectorView)
50+
expectEqual(FooClass.TangentVector(x: [1, 2, 3, 4]), x.differentiableVectorView)
5151
}
5252
}
5353

test/Sema/class_differentiable.swift

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -226,9 +226,6 @@ struct MyVector2 : ElementaryFunctions, Differentiable, EuclideanDifferentiable
226226
self.b = b
227227
}
228228
}
229-
// Won't derive `EuclideanDifferentiable` because `MyVector2.TangentVector != MyVector2`.
230-
// expected-error @+2 {{type 'AllMembersElementaryFunctions' does not conform to protocol 'EuclideanDifferentiable'}}
231-
// expected-note @+1 {{do you want to add protocol stubs?}}
232229
class AllMembersElementaryFunctions : Differentiable, EuclideanDifferentiable {
233230
var v1: MyVector2
234231
var v2: MyVector2

test/Sema/struct_differentiable.swift

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ struct GenericVectorSpacesEqualSelf<T>
9797
func testGenericVectorSpacesEqualSelf() {
9898
var genericSame = GenericVectorSpacesEqualSelf<Double>(w: 1, b: 1)
9999
genericSame.move(along: genericSame)
100-
genericSame.move(along: genericSame.vectorView)
100+
genericSame.move(along: genericSame.differentiableVectorView)
101101
}
102102

103103
// Test nested type.
@@ -130,7 +130,7 @@ struct AllMembersAdditiveArithmetic : Differentiable, EuclideanDifferentiable {
130130

131131
// Test type `AllMembersVectorProtocol` whose members conforms to `VectorProtocol`,
132132
// in which case we should make `TangentVector` conform to `VectorProtocol`.
133-
struct MyVector : VectorProtocol, Differentiable {
133+
struct MyVector : VectorProtocol, Differentiable, EuclideanDifferentiable {
134134
var w: Float
135135
var b: Float
136136
}
@@ -149,9 +149,6 @@ struct MyVector2 : ElementaryFunctions, Differentiable, EuclideanDifferentiable
149149
var b: Float
150150
}
151151

152-
// Won't derive `EuclideanDifferentiable` because `MyVector2.TangentVector != MyVector2`.
153-
// expected-error @+2 {{type 'AllMembersElementaryFunctions' does not conform to protocol 'EuclideanDifferentiable'}}
154-
// expected-note @+1 {{do you want to add protocol stubs?}}
155152
struct AllMembersElementaryFunctions : Differentiable, EuclideanDifferentiable {
156153
var v1: MyVector2
157154
var v2: MyVector2
@@ -186,8 +183,8 @@ struct EuclideanDifferentiableSubset : EuclideanDifferentiable {
186183
func testEuclideanDifferentiableSubset() {
187184
let x = EuclideanDifferentiableSubset(w: 1, b: 2, flag: false)
188185
let tan = EuclideanDifferentiableSubset.TangentVector(w: 1, b: 1)
189-
_ = x.vectorView.w * tan.w
190-
_ = x.vectorView.b * tan.b
186+
_ = x.differentiableVectorView.w * tan.w
187+
_ = x.differentiableVectorView.b * tan.b
191188

192189
_ = pullback(at: x) { model in
193190
model.w + model.b

0 commit comments

Comments
 (0)