Skip to content

Commit d4ff584

Browse files
committed
[AutoDiff upstream] Add AnyDerivative.
Add `AnyDerivative`, a type-erased wrapper for `Differentiable.TangentVector` associated type implementations.
1 parent 65ab642 commit d4ff584

File tree

3 files changed

+449
-0
lines changed

3 files changed

+449
-0
lines changed
Lines changed: 292 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,292 @@
1+
//===--- AnyDerivative.swift ----------------------------------*- swift -*-===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
//
13+
// This file defines `AnyDerivative`, a type-erased wrapper for
14+
// `Differentiable.TangentVector` associated type implementations.
15+
//
16+
//===----------------------------------------------------------------------===//
17+
18+
import Swift
19+
20+
@usableFromInline
21+
internal protocol _AnyDerivativeBox {
22+
// `Equatable` requirements (implied by `AdditiveArithmetic`).
23+
func _isEqual(to other: _AnyDerivativeBox) -> Bool
24+
func _isNotEqual(to other: _AnyDerivativeBox) -> Bool
25+
26+
// `AdditiveArithmetic` requirements.
27+
static var _zero: _AnyDerivativeBox { get }
28+
func _adding(_ x: _AnyDerivativeBox) -> _AnyDerivativeBox
29+
func _subtracting(_ x: _AnyDerivativeBox) -> _AnyDerivativeBox
30+
31+
// `Differentiable` requirements.
32+
mutating func _move(along direction: _AnyDerivativeBox)
33+
34+
/// The underlying base value, type-erased to `Any`.
35+
var _typeErasedBase: Any { get }
36+
37+
/// Returns the underlying value unboxed to the given type, if possible.
38+
func _unboxed<U>(to type: U.Type) -> U?
39+
where U: Differentiable, U.TangentVector == U
40+
}
41+
42+
extension _AnyDerivativeBox {
43+
/// Returns true if the underlying value has type `AnyDerivative.OpaqueZero`.
44+
@inlinable
45+
func _isOpaqueZero() -> Bool {
46+
return _unboxed(to: AnyDerivative.OpaqueZero.self) != nil
47+
}
48+
}
49+
50+
@inline(never)
51+
@usableFromInline
52+
internal func _derivativeTypeMismatch(
53+
_ x: Any.Type, _ y: Any.Type, file: StaticString = #file, line: UInt = #line
54+
) -> Never {
55+
preconditionFailure(
56+
"""
57+
Derivative type mismatch: \
58+
\(String(reflecting: x)) and \(String(reflecting: y))
59+
""", file: file, line: line)
60+
}
61+
62+
@frozen
63+
@usableFromInline
64+
internal struct _ConcreteDerivativeBox<T>: _AnyDerivativeBox
65+
where T: Differentiable, T.TangentVector == T {
66+
/// The underlying base value.
67+
@usableFromInline
68+
var _base: T
69+
70+
@inlinable
71+
internal init(_ base: T) {
72+
self._base = base
73+
}
74+
75+
/// The underlying base value, type-erased to `Any`.
76+
@inlinable
77+
var _typeErasedBase: Any {
78+
return _base
79+
}
80+
81+
@inlinable
82+
func _unboxed<U>(to type: U.Type) -> U?
83+
where U: Differentiable, U.TangentVector == U {
84+
return (self as? _ConcreteDerivativeBox<U>)?._base
85+
}
86+
87+
// `Equatable` requirements (implied by `AdditiveArithmetic`).
88+
@inlinable
89+
func _isEqual(to other: _AnyDerivativeBox) -> Bool {
90+
return _base == other._unboxed(to: T.self)
91+
}
92+
@inlinable
93+
func _isNotEqual(to other: _AnyDerivativeBox) -> Bool {
94+
return _base != other._unboxed(to: T.self)
95+
}
96+
97+
// `AdditiveArithmetic` requirements.
98+
99+
@inlinable
100+
static var _zero: _AnyDerivativeBox {
101+
return _ConcreteDerivativeBox(T.zero)
102+
}
103+
104+
@inlinable
105+
func _adding(_ x: _AnyDerivativeBox) -> _AnyDerivativeBox {
106+
// 0 + x = x
107+
if _isOpaqueZero() {
108+
return x
109+
}
110+
// y + 0 = y
111+
if x._isOpaqueZero() {
112+
return self
113+
}
114+
guard let xBase = x._unboxed(to: T.self) else {
115+
_derivativeTypeMismatch(T.self, type(of: x._typeErasedBase))
116+
}
117+
return _ConcreteDerivativeBox(_base + xBase)
118+
}
119+
120+
@inlinable
121+
func _subtracting(_ x: _AnyDerivativeBox) -> _AnyDerivativeBox {
122+
// y - 0 = y
123+
if x._isOpaqueZero() {
124+
return self
125+
}
126+
// 0 - x = -x
127+
if _isOpaqueZero() {
128+
return type(of: x)._zero._subtracting(x)
129+
}
130+
guard let xBase = x._unboxed(to: T.self) else {
131+
_derivativeTypeMismatch(T.self, type(of: x._typeErasedBase))
132+
}
133+
return _ConcreteDerivativeBox(_base - xBase)
134+
}
135+
136+
// `Differentiable` requirements.
137+
@inlinable
138+
mutating func _move(along direction: _AnyDerivativeBox) {
139+
if direction._isOpaqueZero() {
140+
return
141+
}
142+
// The case where `self._isOpaqueZero()` returns true is handled in
143+
// `AnyDerivative.move(along:)`.
144+
guard
145+
let directionBase =
146+
direction._unboxed(to: T.TangentVector.self)
147+
else {
148+
_derivativeTypeMismatch(T.self, type(of: direction._typeErasedBase))
149+
}
150+
_base.move(along: directionBase)
151+
}
152+
}
153+
154+
/// A type-erased derivative value.
155+
///
156+
/// The `AnyDerivative` type forwards its operations to an arbitrary underlying
157+
/// base derivative value conforming to `Differentiable` and
158+
/// `AdditiveArithmetic`, hiding the specifics of the underlying value.
159+
@frozen
160+
public struct AnyDerivative: Differentiable & AdditiveArithmetic {
161+
@usableFromInline
162+
internal var _box: _AnyDerivativeBox
163+
164+
@inlinable
165+
internal init(_box: _AnyDerivativeBox) {
166+
self._box = _box
167+
}
168+
169+
/// The underlying base value.
170+
@inlinable
171+
public var base: Any {
172+
return _box._typeErasedBase
173+
}
174+
175+
/// Creates a type-erased derivative from the given derivative.
176+
@inlinable
177+
@differentiable
178+
public init<T>(_ base: T) where T: Differentiable, T.TangentVector == T {
179+
self._box = _ConcreteDerivativeBox<T>(base)
180+
}
181+
182+
@inlinable
183+
@derivative(of: init)
184+
internal static func _vjpInit<T>(
185+
_ base: T
186+
) -> (value: AnyDerivative, pullback: (AnyDerivative) -> T.TangentVector)
187+
where T: Differentiable, T.TangentVector == T {
188+
return (AnyDerivative(base), { v in v.base as! T.TangentVector })
189+
}
190+
191+
@inlinable
192+
@derivative(of: init)
193+
internal static func _jvpInit<T>(
194+
_ base: T
195+
) -> (value: AnyDerivative, differential: (T.TangentVector) -> AnyDerivative)
196+
where T: Differentiable, T.TangentVector == T {
197+
return (AnyDerivative(base), { dbase in AnyDerivative(dbase) })
198+
}
199+
200+
public typealias TangentVector = AnyDerivative
201+
202+
// `Equatable` requirements (implied by `AdditiveArithmetic`).
203+
@inlinable
204+
public static func == (lhs: AnyDerivative, rhs: AnyDerivative) -> Bool {
205+
return lhs._box._isEqual(to: rhs._box)
206+
}
207+
@inlinable
208+
public static func != (lhs: AnyDerivative, rhs: AnyDerivative) -> Bool {
209+
return lhs._box._isNotEqual(to: rhs._box)
210+
}
211+
212+
// `AdditiveArithmetic` requirements.
213+
214+
/// Internal struct representing an opaque zero value.
215+
@frozen
216+
@usableFromInline
217+
internal struct OpaqueZero: Differentiable & AdditiveArithmetic {}
218+
219+
@inlinable
220+
public static var zero: AnyDerivative {
221+
return AnyDerivative(
222+
_box: _ConcreteDerivativeBox<OpaqueZero>(OpaqueZero.zero))
223+
}
224+
225+
@inlinable
226+
public static func + (
227+
lhs: AnyDerivative, rhs: AnyDerivative
228+
) -> AnyDerivative {
229+
return AnyDerivative(_box: lhs._box._adding(rhs._box))
230+
}
231+
232+
@derivative(of: +)
233+
@inlinable
234+
internal static func _vjpAdd(
235+
lhs: AnyDerivative, rhs: AnyDerivative
236+
) -> (
237+
value: AnyDerivative,
238+
pullback: (AnyDerivative) -> (AnyDerivative, AnyDerivative)
239+
) {
240+
return (lhs + rhs, { v in (v, v) })
241+
}
242+
243+
@derivative(of: +)
244+
@inlinable
245+
internal static func _jvpAdd(
246+
lhs: AnyDerivative, rhs: AnyDerivative
247+
) -> (
248+
value: AnyDerivative,
249+
differential: (AnyDerivative, AnyDerivative) -> (AnyDerivative)
250+
) {
251+
return (lhs + rhs, { (dlhs, drhs) in dlhs + drhs })
252+
}
253+
254+
@inlinable
255+
public static func - (
256+
lhs: AnyDerivative, rhs: AnyDerivative
257+
) -> AnyDerivative {
258+
return AnyDerivative(_box: lhs._box._subtracting(rhs._box))
259+
}
260+
261+
@derivative(of: -)
262+
@inlinable
263+
internal static func _vjpSubtract(
264+
lhs: AnyDerivative, rhs: AnyDerivative
265+
) -> (
266+
value: AnyDerivative,
267+
pullback: (AnyDerivative) -> (AnyDerivative, AnyDerivative)
268+
) {
269+
return (lhs - rhs, { v in (v, .zero - v) })
270+
}
271+
272+
@derivative(of: -)
273+
@inlinable
274+
internal static func _jvpSubtract(
275+
lhs: AnyDerivative, rhs: AnyDerivative
276+
) -> (
277+
value: AnyDerivative,
278+
differential: (AnyDerivative, AnyDerivative) -> AnyDerivative
279+
) {
280+
return (lhs - rhs, { (dlhs, drhs) in dlhs - drhs })
281+
}
282+
283+
// `Differentiable` requirements.
284+
@inlinable
285+
public mutating func move(along direction: TangentVector) {
286+
if _box._isOpaqueZero() {
287+
_box = direction._box
288+
return
289+
}
290+
_box._move(along: direction._box)
291+
}
292+
}

stdlib/public/Differentiation/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ add_swift_target_library(swift_Differentiation ${SWIFT_STDLIB_LIBRARY_BUILD_TYPE
1414
Differentiable.swift
1515
DifferentialOperators.swift
1616
DifferentiationUtilities.swift
17+
AnyDerivative.swift
1718
ArrayDifferentiation.swift
1819

1920
GYB_SOURCES

0 commit comments

Comments
 (0)