Skip to content

Commit c76e7bb

Browse files
authored
---
yaml --- r: 312279 b: refs/heads/tensorflow-merge c: 9a9822e h: refs/heads/master i: 312277: 4c3eea1 312275: 6b84b95 312271: d2e6ebe
1 parent 4d72db4 commit c76e7bb

File tree

3 files changed

+51
-1
lines changed

3 files changed

+51
-1
lines changed

[refs]

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1379,7 +1379,7 @@ refs/heads/chase-my-tail: 8bb91443a9e81bbfac92a2621a0af887a1da8dbf
13791379
refs/heads/consider-outer-alternatives: 708bac749ec60a22a79e2eefbe734f9488a7370d
13801380
refs/heads/revert-25740-oops-i-linked-it-again: fdd41aeb682fc488572bdc1cf71b2ff6997ba576
13811381
refs/heads/swift-5.1-branch-06-12-2019: e63b7b2d3b93c48232d386099d0ec525d21d8f8d
1382-
refs/heads/tensorflow-merge: 92449819bc98bb921402d9a5b8c41199d4e37038
1382+
refs/heads/tensorflow-merge: 9a9822ed7185848f5de55e4983fa13e7094dfea1
13831383
refs/heads/update-checkout-sha-info: 5832743c5c2a842976c42a508a4c6dcceefb0aef
13841384
refs/tags/swift-5.1-DEVELOPMENT-SNAPSHOT-2019-06-12-a: 228f0448d9bb909aacbba4afcb7c600a405d15da
13851385
refs/tags/swift-5.1-DEVELOPMENT-SNAPSHOT-2019-06-14-a: 922861a77b5fc2bf46bc917da70ceb15eef76836

branches/tensorflow-merge/stdlib/public/core/AutoDiff.swift

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -678,6 +678,7 @@ public struct AnyDerivative : Differentiable & AdditiveArithmetic {
678678
}
679679

680680
/// Creates a type-erased derivative from the given derivative.
681+
@differentiable(vjp: _vjpInit(_:))
681682
public init<T>(_ base: T)
682683
where T : Differentiable, T.TangentVector == T,
683684
T.AllDifferentiableVariables == T,
@@ -688,6 +689,18 @@ public struct AnyDerivative : Differentiable & AdditiveArithmetic {
688689
self._box = _ConcreteDerivativeBox<T>(base)
689690
}
690691

692+
@usableFromInline internal static func _vjpInit<T>(
693+
_ base: T
694+
) -> (AnyDerivative, (AnyDerivative) -> T.CotangentVector)
695+
where T : Differentiable, T.TangentVector == T,
696+
T.AllDifferentiableVariables == T,
697+
// NOTE: The requirement below should be defined on `Differentiable`.
698+
// But it causes a crash due to generic signature minimization bug.
699+
T.CotangentVector == T.CotangentVector.AllDifferentiableVariables
700+
{
701+
return (AnyDerivative(base), { v in v.base as! T.CotangentVector })
702+
}
703+
691704
public typealias TangentVector = AnyDerivative
692705
public typealias CotangentVector = AnyDerivative
693706
public typealias AllDifferentiableVariables = AnyDerivative

branches/tensorflow-merge/test/AutoDiff/anyderivative.swift

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ AnyDerivativeTests.test("Casting") {
7575
}
7676

7777
AnyDerivativeTests.test("Derivatives") {
78+
// Test `AnyDerivative` operations.
7879
func tripleSum(_ x: AnyDerivative, _ y: AnyDerivative) -> AnyDerivative {
7980
let sum = x + y
8081
return sum + sum + sum
@@ -112,6 +113,42 @@ AnyDerivativeTests.test("Derivatives") {
112113
expectEqual(expectedVJP, 𝛁x.base as? Generic<Double>.CotangentVector)
113114
expectEqual(expectedVJP, 𝛁y.base as? Generic<Double>.CotangentVector)
114115
}
116+
117+
// Test `AnyDerivative` initializer.
118+
func typeErased<T>(_ x: T) -> AnyDerivative
119+
where T : Differentiable, T.TangentVector == T,
120+
T.AllDifferentiableVariables == T,
121+
// NOTE: The requirement below should be defined on `Differentiable`.
122+
// But it causes a crash due to generic signature minimization bug.
123+
T.CotangentVector == T.CotangentVector.AllDifferentiableVariables
124+
{
125+
let any = AnyDerivative(x)
126+
return any + any
127+
}
128+
129+
do {
130+
let x: Float = 3
131+
let v = AnyDerivative(Float(1))
132+
let 𝛁x = pullback(at: x, in: { x in typeErased(x) })(v)
133+
let expectedVJP: Float = 2
134+
expectEqual(expectedVJP, 𝛁x)
135+
}
136+
137+
do {
138+
let x = Vector.TangentVector(x: 4, y: 5)
139+
let v = AnyDerivative(Vector.CotangentVector(x: 1, y: 1))
140+
let 𝛁x = pullback(at: x, in: { x in typeErased(x) })(v)
141+
let expectedVJP = Vector.CotangentVector(x: 2, y: 2)
142+
expectEqual(expectedVJP, 𝛁x)
143+
}
144+
145+
do {
146+
let x = Generic<Double>.TangentVector(x: 4)
147+
let v = AnyDerivative(Generic<Double>.CotangentVector(x: 1))
148+
let 𝛁x = pullback(at: x, in: { x in typeErased(x) })(v)
149+
let expectedVJP = Generic<Double>.CotangentVector(x: 2)
150+
expectEqual(expectedVJP, 𝛁x)
151+
}
115152
}
116153

117154
runAllTests()

0 commit comments

Comments
 (0)