Skip to content

Commit ddef929

Browse files
author
marcrasi
authored
[AutoDiff upstream] DifferentiationUnittest and some e2e tests (#30915)
Adds 2 simple e2e tests and some lit subsitutions and unittest libraries necessary to support them.
1 parent 1d7b3a5 commit ddef929

File tree

6 files changed

+825
-1
lines changed

6 files changed

+825
-1
lines changed

stdlib/private/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@ endif()
55
if(SWIFT_BUILD_SDK_OVERLAY)
66
# SwiftPrivateThreadExtras makes use of Darwin/Glibc, which is part of the
77
# SDK overlay. It can't be built separately from the SDK overlay.
8+
if(SWIFT_ENABLE_EXPERIMENTAL_DIFFERENTIABLE_PROGRAMMING)
9+
add_subdirectory(DifferentiationUnittest)
10+
endif()
811
add_subdirectory(RuntimeUnittest)
912
add_subdirectory(StdlibUnicodeUnittest)
1013
add_subdirectory(StdlibCollectionUnittest)
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
add_swift_target_library(swiftDifferentiationUnittest ${SWIFT_STDLIB_LIBRARY_BUILD_TYPES} IS_STDLIB
2+
# This file should be listed first. Module name is inferred from the filename.
3+
DifferentiationUnittest.swift
4+
5+
SWIFT_MODULE_DEPENDS _Differentiation StdlibUnittest
6+
INSTALL_IN_COMPONENT stdlib-experimental
7+
DARWIN_INSTALL_NAME_DIR "${SWIFT_DARWIN_STDLIB_PRIVATE_INSTALL_NAME_DIR}")
Lines changed: 304 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,304 @@
1+
//===--- DifferentiationUnittest.swift ------------------------------------===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2020 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
@_exported import _Differentiation
14+
import StdlibUnittest
15+
16+
public enum _GlobalLeakCount {
17+
public static var count = 0
18+
}
19+
20+
/// Execute body and check expected leak count.
21+
public func withLeakChecking(
22+
expectedLeakCount: Int = 0, file: String = #file, line: UInt = #line,
23+
_ body: () -> Void
24+
) {
25+
// Note: compare expected leak count with relative leak count after
26+
// running `body`.
27+
// This approach is more robust than comparing leak count with zero
28+
// and resetting leak count to zero, which is stateful and causes issues.
29+
let beforeLeakCount = _GlobalLeakCount.count
30+
body()
31+
let leakCount = _GlobalLeakCount.count - beforeLeakCount
32+
expectEqual(
33+
expectedLeakCount, leakCount, "Leaks detected: \(leakCount)",
34+
file: file, line: line)
35+
}
36+
37+
public extension TestSuite {
38+
/// Execute test function and check expected leak count.
39+
func testWithLeakChecking(
40+
_ name: String,
41+
expectedLeakCount: Int = 0,
42+
file: String = #file, line: UInt = #line,
43+
_ testFunction: @escaping () -> Void
44+
) {
45+
test(name, file: file, line: line) {
46+
withLeakChecking(expectedLeakCount: expectedLeakCount, file: file,
47+
line: line, testFunction)
48+
}
49+
}
50+
}
51+
52+
/// A type that tracks the number of live instances of a wrapped value type.
53+
///
54+
/// `Tracked<T>` is used to check for memory leaks in functions created via
55+
/// automatic differentiation.
56+
public struct Tracked<T> {
57+
fileprivate class Box {
58+
fileprivate var value : T
59+
init(_ value: T) {
60+
self.value = value
61+
_GlobalLeakCount.count += 1
62+
}
63+
deinit {
64+
_GlobalLeakCount.count -= 1
65+
}
66+
}
67+
private var handle: Box
68+
69+
@differentiable(where T : Differentiable, T == T.TangentVector)
70+
public init(_ value: T) {
71+
self.handle = Box(value)
72+
}
73+
74+
@differentiable(where T : Differentiable, T == T.TangentVector)
75+
public var value: T {
76+
get { handle.value }
77+
set { handle.value = newValue }
78+
}
79+
}
80+
81+
extension Tracked : ExpressibleByFloatLiteral where T : ExpressibleByFloatLiteral {
82+
public init(floatLiteral value: T.FloatLiteralType) {
83+
self.handle = Box(T(floatLiteral: value))
84+
}
85+
}
86+
87+
extension Tracked : CustomStringConvertible {
88+
public var description: String { return "Tracked(\(value))" }
89+
}
90+
91+
extension Tracked : ExpressibleByIntegerLiteral where T : ExpressibleByIntegerLiteral {
92+
public init(integerLiteral value: T.IntegerLiteralType) {
93+
self.handle = Box(T(integerLiteral: value))
94+
}
95+
}
96+
97+
extension Tracked : Comparable where T : Comparable {
98+
public static func < (lhs: Tracked, rhs: Tracked) -> Bool {
99+
return lhs.value < rhs.value
100+
}
101+
public static func <= (lhs: Tracked, rhs: Tracked) -> Bool {
102+
return lhs.value <= rhs.value
103+
}
104+
public static func > (lhs: Tracked, rhs: Tracked) -> Bool {
105+
return lhs.value > rhs.value
106+
}
107+
public static func >= (lhs: Tracked, rhs: Tracked) -> Bool {
108+
return lhs.value >= rhs.value
109+
}
110+
}
111+
112+
extension Tracked : AdditiveArithmetic where T : AdditiveArithmetic {
113+
public static var zero: Tracked { return Tracked(T.zero) }
114+
public static func + (lhs: Tracked, rhs: Tracked) -> Tracked {
115+
return Tracked(lhs.value + rhs.value)
116+
}
117+
public static func - (lhs: Tracked, rhs: Tracked) -> Tracked {
118+
return Tracked(lhs.value - rhs.value)
119+
}
120+
}
121+
122+
extension Tracked : Equatable where T : Equatable {
123+
public static func == (lhs: Tracked, rhs: Tracked) -> Bool {
124+
return lhs.value == rhs.value
125+
}
126+
}
127+
128+
extension Tracked : SignedNumeric & Numeric where T : SignedNumeric, T == T.Magnitude {
129+
public typealias Magnitude = Tracked<T.Magnitude>
130+
131+
public init?<U>(exactly source: U) where U : BinaryInteger {
132+
if let t = T(exactly: source) {
133+
self.init(t)
134+
}
135+
return nil
136+
}
137+
public var magnitude: Magnitude { return Magnitude(value.magnitude) }
138+
139+
public static func * (lhs: Tracked, rhs: Tracked) -> Tracked {
140+
return Tracked(lhs.value * rhs.value)
141+
}
142+
143+
public static func *= (lhs: inout Tracked, rhs: Tracked) {
144+
lhs = lhs * rhs
145+
}
146+
}
147+
148+
extension Tracked where T : FloatingPoint {
149+
public static func / (lhs: Tracked, rhs: Tracked) -> Tracked {
150+
return Tracked(lhs.value / rhs.value)
151+
}
152+
153+
public static func /= (lhs: inout Tracked, rhs: Tracked) {
154+
lhs = lhs / rhs
155+
}
156+
}
157+
158+
extension Tracked : Strideable where T : Strideable, T.Stride == T.Stride.Magnitude {
159+
public typealias Stride = Tracked<T.Stride>
160+
161+
public func distance(to other: Tracked) -> Stride {
162+
return Stride(value.distance(to: other.value))
163+
}
164+
public func advanced(by n: Stride) -> Tracked {
165+
return Tracked(value.advanced(by: n.value))
166+
}
167+
}
168+
169+
// For now, `T` must be restricted to trivial types (like `Float` or `Tensor`).
170+
extension Tracked : Differentiable where T : Differentiable, T == T.TangentVector {
171+
public typealias TangentVector = Tracked<T.TangentVector>
172+
}
173+
174+
extension Tracked where T : Differentiable, T == T.TangentVector {
175+
@usableFromInline
176+
@derivative(of: init)
177+
internal static func _vjpInit(_ value: T)
178+
-> (value: Self, pullback: (Self.TangentVector) -> (T.TangentVector)) {
179+
return (Tracked(value), { v in v.value })
180+
}
181+
182+
@usableFromInline
183+
@derivative(of: init)
184+
internal static func _jvpInit(_ value: T)
185+
-> (value: Self, differential: (T.TangentVector) -> (Self.TangentVector)) {
186+
return (Tracked(value), { v in Tracked(v) })
187+
}
188+
189+
@usableFromInline
190+
@derivative(of: value)
191+
internal func _vjpValue() -> (value: T, pullback: (T.TangentVector) -> Self.TangentVector) {
192+
return (value, { v in Tracked(v) })
193+
}
194+
195+
@usableFromInline
196+
@derivative(of: value)
197+
internal func _jvpValue() -> (value: T, differential: (Self.TangentVector) -> T.TangentVector) {
198+
return (value, { v in v.value })
199+
}
200+
}
201+
202+
extension Tracked where T : Differentiable, T == T.TangentVector {
203+
@usableFromInline
204+
@derivative(of: +)
205+
internal static func _vjpAdd(lhs: Self, rhs: Self)
206+
-> (value: Self, pullback: (Self) -> (Self, Self)) {
207+
return (lhs + rhs, { v in (v, v) })
208+
}
209+
210+
@usableFromInline
211+
@derivative(of: +)
212+
internal static func _jvpAdd(lhs: Self, rhs: Self)
213+
-> (value: Self, differential: (Self, Self) -> Self) {
214+
return (lhs + rhs, { $0 + $1 })
215+
}
216+
217+
@usableFromInline
218+
@derivative(of: -)
219+
internal static func _vjpSubtract(lhs: Self, rhs: Self)
220+
-> (value: Self, pullback: (Self) -> (Self, Self)) {
221+
return (lhs - rhs, { v in (v, .zero - v) })
222+
}
223+
224+
@usableFromInline
225+
@derivative(of: -)
226+
internal static func _jvpSubtract(lhs: Self, rhs: Self)
227+
-> (value: Self, differential: (Self, Self) -> Self) {
228+
return (lhs - rhs, { $0 - $1 })
229+
}
230+
}
231+
232+
extension Tracked where T : Differentiable & SignedNumeric, T == T.Magnitude,
233+
T == T.TangentVector {
234+
@usableFromInline
235+
@derivative(of: *)
236+
internal static func _vjpMultiply(lhs: Self, rhs: Self)
237+
-> (value: Self, pullback: (Self) -> (Self, Self)) {
238+
return (lhs * rhs, { v in (v * rhs, v * lhs) })
239+
}
240+
241+
@usableFromInline
242+
@derivative(of: *)
243+
internal static func _jvpMultiply(lhs: Self, rhs: Self)
244+
-> (value: Self, differential: (Self, Self) -> (Self)) {
245+
return (lhs * rhs, { (dx, dy) in dx * rhs + dy * lhs })
246+
}
247+
}
248+
249+
extension Tracked where T : Differentiable & FloatingPoint, T == T.TangentVector {
250+
@usableFromInline
251+
@derivative(of: /)
252+
internal static func _vjpDivide(lhs: Self, rhs: Self)
253+
-> (value: Self, pullback: (Self) -> (Self, Self)) {
254+
return (lhs / rhs, { v in (v / rhs, -lhs / (rhs * rhs) * v) })
255+
}
256+
257+
@usableFromInline
258+
@derivative(of: /)
259+
internal static func _jvpDivide(lhs: Self, rhs: Self)
260+
-> (value: Self, differential: (Self, Self) -> (Self)) {
261+
return (lhs / rhs, { (dx, dy) in dx / rhs - lhs / (rhs * rhs) * dy })
262+
}
263+
}
264+
265+
// Differential operators for `Tracked<T>`.
266+
267+
public func gradient<T, R: FloatingPoint>(
268+
at x: T, in f: @differentiable (T) -> Tracked<R>
269+
) -> T.TangentVector where R.TangentVector == R {
270+
return pullback(at: x, in: f)(1)
271+
}
272+
273+
public func gradient<T, U, R: FloatingPoint>(
274+
at x: T, _ y: U, in f: @differentiable (T, U) -> Tracked<R>
275+
) -> (T.TangentVector, U.TangentVector) where R.TangentVector == R {
276+
return pullback(at: x, y, in: f)(1)
277+
}
278+
279+
public func derivative<T: FloatingPoint, R>(
280+
at x: Tracked<T>, in f: @differentiable (Tracked<T>) -> R
281+
) -> R.TangentVector where T.TangentVector == T {
282+
return differential(at: x, in: f)(1)
283+
}
284+
285+
public func derivative<T: FloatingPoint, U: FloatingPoint, R>(
286+
at x: Tracked<T>, _ y: Tracked<U>,
287+
in f: @differentiable (Tracked<T>, Tracked<U>) -> R
288+
) -> R.TangentVector where T.TangentVector == T, U.TangentVector == U {
289+
return differential(at: x, y, in: f)(1, 1)
290+
}
291+
292+
public func valueWithGradient<T, R: FloatingPoint>(
293+
at x: T, in f: @differentiable (T) -> Tracked<R>
294+
) -> (value: Tracked<R>, gradient: T.TangentVector) {
295+
let (y, pullback) = valueWithPullback(at: x, in: f)
296+
return (y, pullback(1))
297+
}
298+
299+
public func valueWithDerivative<T: FloatingPoint, R>(
300+
at x: Tracked<T>, in f: @differentiable (Tracked<T>) -> R
301+
) -> (value: R, derivative: R.TangentVector) {
302+
let (y, differential) = valueWithDifferential(at: x, in: f)
303+
return (y, differential(1))
304+
}

0 commit comments

Comments
 (0)