Skip to content

Commit 9a9822e

Browse files
authored
[AutoDiff] Register derivative for AnyDerivative.init. (#24013)
`AnyDerivative.init` has type `(T) -> AnyDerivative`. Its pullback has type `(AnyDerivative) -> T.CotangentVector`.
1 parent 9244981 commit 9a9822e

File tree

2 files changed

+50
-0
lines changed

2 files changed

+50
-0
lines changed

stdlib/public/core/AutoDiff.swift

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -678,6 +678,7 @@ public struct AnyDerivative : Differentiable & AdditiveArithmetic {
678678
}
679679

680680
/// Creates a type-erased derivative from the given derivative.
681+
@differentiable(vjp: _vjpInit(_:))
681682
public init<T>(_ base: T)
682683
where T : Differentiable, T.TangentVector == T,
683684
T.AllDifferentiableVariables == T,
@@ -688,6 +689,18 @@ public struct AnyDerivative : Differentiable & AdditiveArithmetic {
688689
self._box = _ConcreteDerivativeBox<T>(base)
689690
}
690691

692+
@usableFromInline internal static func _vjpInit<T>(
693+
_ base: T
694+
) -> (AnyDerivative, (AnyDerivative) -> T.CotangentVector)
695+
where T : Differentiable, T.TangentVector == T,
696+
T.AllDifferentiableVariables == T,
697+
// NOTE: The requirement below should be defined on `Differentiable`.
698+
// But it causes a crash due to generic signature minimization bug.
699+
T.CotangentVector == T.CotangentVector.AllDifferentiableVariables
700+
{
701+
return (AnyDerivative(base), { v in v.base as! T.CotangentVector })
702+
}
703+
691704
public typealias TangentVector = AnyDerivative
692705
public typealias CotangentVector = AnyDerivative
693706
public typealias AllDifferentiableVariables = AnyDerivative

test/AutoDiff/anyderivative.swift

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ AnyDerivativeTests.test("Casting") {
7575
}
7676

7777
AnyDerivativeTests.test("Derivatives") {
78+
// Test `AnyDerivative` operations.
7879
func tripleSum(_ x: AnyDerivative, _ y: AnyDerivative) -> AnyDerivative {
7980
let sum = x + y
8081
return sum + sum + sum
@@ -112,6 +113,42 @@ AnyDerivativeTests.test("Derivatives") {
112113
expectEqual(expectedVJP, 𝛁x.base as? Generic<Double>.CotangentVector)
113114
expectEqual(expectedVJP, 𝛁y.base as? Generic<Double>.CotangentVector)
114115
}
116+
117+
// Test `AnyDerivative` initializer.
118+
func typeErased<T>(_ x: T) -> AnyDerivative
119+
where T : Differentiable, T.TangentVector == T,
120+
T.AllDifferentiableVariables == T,
121+
// NOTE: The requirement below should be defined on `Differentiable`.
122+
// But it causes a crash due to generic signature minimization bug.
123+
T.CotangentVector == T.CotangentVector.AllDifferentiableVariables
124+
{
125+
let any = AnyDerivative(x)
126+
return any + any
127+
}
128+
129+
do {
130+
let x: Float = 3
131+
let v = AnyDerivative(Float(1))
132+
let 𝛁x = pullback(at: x, in: { x in typeErased(x) })(v)
133+
let expectedVJP: Float = 2
134+
expectEqual(expectedVJP, 𝛁x)
135+
}
136+
137+
do {
138+
let x = Vector.TangentVector(x: 4, y: 5)
139+
let v = AnyDerivative(Vector.CotangentVector(x: 1, y: 1))
140+
let 𝛁x = pullback(at: x, in: { x in typeErased(x) })(v)
141+
let expectedVJP = Vector.CotangentVector(x: 2, y: 2)
142+
expectEqual(expectedVJP, 𝛁x)
143+
}
144+
145+
do {
146+
let x = Generic<Double>.TangentVector(x: 4)
147+
let v = AnyDerivative(Generic<Double>.CotangentVector(x: 1))
148+
let 𝛁x = pullback(at: x, in: { x in typeErased(x) })(v)
149+
let expectedVJP = Generic<Double>.CotangentVector(x: 2)
150+
expectEqual(expectedVJP, 𝛁x)
151+
}
115152
}
116153

117154
runAllTests()

0 commit comments

Comments
 (0)