Skip to content

Commit b575448

Browse files
bgogulrxwei
authored andcommitted
Rewrite some of the AD tests with Tracked<Float> (#27733)
First of the several PRs that will enable leak checking for the AD tests. [TF-895](https://bugs.swift.org/browse/TF-895)
1 parent 7fdc5d0 commit b575448

7 files changed

+97
-86
lines changed

test/AutoDiff/currying.swift

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,23 @@
11
// RUN: %target-run-simple-swift
22

33
import StdlibUnittest
4+
import DifferentiationUnittest
45

56
var CurryingAutodiffTests = TestSuite("CurryingAutodiff")
67

7-
CurryingAutodiffTests.test("StructMember") {
8+
CurryingAutodiffTests.testWithLeakChecking("StructMember") {
89
struct A {
910
@differentiable(wrt: (value))
10-
func v(_ value: Float) -> Float { return value * value }
11+
func v(_ value: Tracked<Float>) -> Tracked<Float> { return value * value }
1112
}
1213

1314
let a = A()
14-
// This implicitly constructs a function (A) -> (Float) -> Float
15+
// This implicitly constructs a function (A) -> (Tracked<Float>) -> Tracked<Float>
1516
// which gets called with a:
16-
let g: @differentiable (Float) -> Float = a.v
17+
let g: @differentiable (Tracked<Float>) -> Tracked<Float> = a.v
1718

1819

19-
expectEqual(6.0, Float(3.0).gradient(in: g))
20+
expectEqual(6.0, Tracked<Float>(3.0).gradient(in: g))
2021
}
2122

2223
runAllTests()

test/AutoDiff/derivative_registration.swift

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,70 +2,71 @@
22
// REQUIRES: executable_test
33

44
import StdlibUnittest
5+
import DifferentiationUnittest
56

67
var DerivativeRegistrationTests = TestSuite("DerivativeRegistration")
78

89
@_semantics("autodiff.opaque")
9-
func unary(x: Float) -> Float {
10+
func unary(x: Tracked<Float>) -> Tracked<Float> {
1011
return x
1112
}
1213
@differentiating(unary)
13-
func _vjpUnary(x: Float) -> (value: Float, pullback: (Float) -> Float) {
14+
func _vjpUnary(x: Tracked<Float>) -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> Tracked<Float>) {
1415
return (value: x, pullback: { v in v })
1516
}
16-
DerivativeRegistrationTests.test("UnaryFreeFunction") {
17+
DerivativeRegistrationTests.testWithLeakChecking("UnaryFreeFunction") {
1718
expectEqual(1, gradient(at: 3.0, in: unary))
1819
}
1920

2021
@_semantics("autodiff.opaque")
21-
func multiply(_ x: Float, _ y: Float) -> Float {
22+
func multiply(_ x: Tracked<Float>, _ y: Tracked<Float>) -> Tracked<Float> {
2223
return x * y
2324
}
2425
@differentiating(multiply)
25-
func _vjpMultiply(_ x: Float, _ y: Float)
26-
-> (value: Float, pullback: (Float) -> (Float, Float)) {
26+
func _vjpMultiply(_ x: Tracked<Float>, _ y: Tracked<Float>)
27+
-> (value: Tracked<Float>, pullback: (Tracked<Float>) -> (Tracked<Float>, Tracked<Float>)) {
2728
return (x * y, { v in (v * y, v * x) })
2829
}
29-
DerivativeRegistrationTests.test("BinaryFreeFunction") {
30+
DerivativeRegistrationTests.testWithLeakChecking("BinaryFreeFunction") {
3031
expectEqual((3.0, 2.0), gradient(at: 2.0, 3.0, in: { x, y in multiply(x, y) }))
3132
}
3233

3334
struct Wrapper : Differentiable {
34-
var float: Float
35+
var float: Tracked<Float>
3536
}
3637

3738
extension Wrapper {
3839
@_semantics("autodiff.opaque")
39-
static func multiply(_ x: Float, _ y: Float) -> Float {
40+
static func multiply(_ x: Tracked<Float>, _ y: Tracked<Float>) -> Tracked<Float> {
4041
return x * y
4142
}
4243

4344
@differentiating(multiply)
44-
static func _vjpMultiply(_ x: Float, _ y: Float)
45-
-> (value: Float, pullback: (Float) -> (Float, Float)) {
45+
static func _vjpMultiply(_ x: Tracked<Float>, _ y: Tracked<Float>)
46+
-> (value: Tracked<Float>, pullback: (Tracked<Float>) -> (Tracked<Float>, Tracked<Float>)) {
4647
return (x * y, { v in (v * y, v * x) })
4748
}
4849
}
49-
DerivativeRegistrationTests.test("StaticMethod") {
50+
DerivativeRegistrationTests.testWithLeakChecking("StaticMethod") {
5051
expectEqual((3.0, 2.0), gradient(at: 2.0, 3.0, in: { x, y in Wrapper.multiply(x, y) }))
5152
}
5253

5354
extension Wrapper {
5455
@_semantics("autodiff.opaque")
55-
func multiply(_ x: Float) -> Float {
56+
func multiply(_ x: Tracked<Float>) -> Tracked<Float> {
5657
return float * x
5758
}
5859

5960
@differentiating(multiply)
60-
func _vjpMultiply(_ x: Float)
61-
-> (value: Float, pullback: (Float) -> (Wrapper.TangentVector, Float)) {
61+
func _vjpMultiply(_ x: Tracked<Float>)
62+
-> (value: Tracked<Float>, pullback: (Tracked<Float>) -> (Wrapper.TangentVector, Tracked<Float>)) {
6263
return (float * x, { v in
6364
(Wrapper.TangentVector(float: v * x), v * self.float)
6465
})
6566
}
6667
}
67-
DerivativeRegistrationTests.test("InstanceMethod") {
68-
let x: Float = 2
68+
DerivativeRegistrationTests.testWithLeakChecking("InstanceMethod") {
69+
let x: Tracked<Float> = 2
6970
let wrapper = Wrapper(float: 3)
7071
let (𝛁wrapper, 𝛁x) = wrapper.gradient(at: x) { wrapper, x in wrapper.multiply(x) }
7172
expectEqual(Wrapper.TangentVector(float: 2), 𝛁wrapper)

test/AutoDiff/protocol_requirement_autodiff.swift

Lines changed: 34 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,20 @@
11
// RUN: %target-run-simple-swift
22

33
import StdlibUnittest
4+
import DifferentiationUnittest
45

56
var ProtocolRequirementAutodiffTests = TestSuite("ProtocolRequirementAutodiff")
67

78
// MARK: - Func requirements.
89

910
protocol DiffReq : Differentiable {
1011
@differentiable(wrt: (self, x))
11-
func f(_ x: Float) -> Float
12+
func f(_ x: Tracked<Float>) -> Tracked<Float>
1213
}
1314

1415
extension DiffReq where TangentVector : AdditiveArithmetic {
1516
@inline(never) // Prevent specialization, to test all witness code.
16-
func gradF(at x: Float) -> (Self.TangentVector, Float) {
17+
func gradF(at x: Tracked<Float>) -> (Self.TangentVector, Tracked<Float>) {
1718
return (valueWithPullback(at: x) { s, x in s.f(x) }).1(1)
1819
}
1920
}
@@ -22,27 +23,27 @@ struct Quadratic : DiffReq, AdditiveArithmetic {
2223
typealias TangentVector = Quadratic
2324

2425
@differentiable
25-
let a: Float
26+
let a: Tracked<Float>
2627

2728
@differentiable
28-
let b: Float
29+
let b: Tracked<Float>
2930

3031
@differentiable
31-
let c: Float
32+
let c: Tracked<Float>
3233

33-
init(_ a: Float, _ b: Float, _ c: Float) {
34+
init(_ a: Tracked<Float>, _ b: Tracked<Float>, _ c: Tracked<Float>) {
3435
self.a = a
3536
self.b = b
3637
self.c = c
3738
}
3839

3940
@differentiable(wrt: (self, x))
40-
func f(_ x: Float) -> Float {
41+
func f(_ x: Tracked<Float>) -> Tracked<Float> {
4142
return a * x * x + b * x + c
4243
}
4344
}
4445

45-
ProtocolRequirementAutodiffTests.test("func") {
46+
ProtocolRequirementAutodiffTests.testWithLeakChecking("func") {
4647
expectEqual((Quadratic(0, 0, 1), 12), Quadratic(11, 12, 13).gradF(at: 0))
4748
expectEqual((Quadratic(1, 1, 1), 2 * 11 + 12),
4849
Quadratic(11, 12, 13).gradF(at: 1))
@@ -54,48 +55,48 @@ ProtocolRequirementAutodiffTests.test("func") {
5455

5556
protocol FunctionsOfX: Differentiable {
5657
@differentiable
57-
init(x: Float)
58+
init(x: Tracked<Float>)
5859

5960
@differentiable
60-
var x: Float { get }
61+
var x: Tracked<Float> { get }
6162

6263
@differentiable
63-
var y: Float { get }
64+
var y: Tracked<Float> { get }
6465

6566
@differentiable
66-
var z: Float { get }
67+
var z: Tracked<Float> { get }
6768

6869
@differentiable
69-
subscript() -> Float { get }
70+
subscript() -> Tracked<Float> { get }
7071
}
7172

7273
struct TestFunctionsOfX: FunctionsOfX {
7374
@differentiable
74-
init(x: Float) {
75+
init(x: Tracked<Float>) {
7576
self.x = x
7677
self.y = x * x
7778
}
7879

7980
/// x = x
80-
var x: Float
81+
var x: Tracked<Float>
8182

8283
/// y = x * x
83-
var y: Float
84+
var y: Tracked<Float>
8485

8586
/// z = x * x + x
86-
var z: Float {
87+
var z: Tracked<Float> {
8788
return y + x
8889
}
8990

9091
@differentiable
91-
subscript() -> Float {
92+
subscript() -> Tracked<Float> {
9293
return z
9394
}
9495
}
9596

9697
@inline(never) // Prevent specialization, to test all witness code.
97-
func derivatives<F: FunctionsOfX>(at x: Float, in: F.Type)
98-
-> (Float, Float, Float, Float)
98+
func derivatives<F: FunctionsOfX>(at x: Tracked<Float>, in: F.Type)
99+
-> (Tracked<Float>, Tracked<Float>, Tracked<Float>, Tracked<Float>)
99100
{
100101
let dxdx = gradient(at: x) { x in F(x: x).x }
101102
let dydx = gradient(at: x) { x in F(x: x).y }
@@ -104,7 +105,7 @@ func derivatives<F: FunctionsOfX>(at x: Float, in: F.Type)
104105
return (dxdx, dydx, dzdx, dsubscriptdx)
105106
}
106107

107-
ProtocolRequirementAutodiffTests.test("constructor, accessor, subscript") {
108+
ProtocolRequirementAutodiffTests.testWithLeakChecking("constructor, accessor, subscript") {
108109
expectEqual(
109110
(1.0, 4.0, 5.0, 5.0),
110111
derivatives(at: 2.0, in: TestFunctionsOfX.self))
@@ -114,11 +115,11 @@ ProtocolRequirementAutodiffTests.test("constructor, accessor, subscript") {
114115

115116
protocol P : Differentiable {
116117
@differentiable(wrt: (x, y))
117-
func foo(_ x: Float, _ y: Double) -> Float
118+
func foo(_ x: Tracked<Float>, _ y: Double) -> Tracked<Float>
118119
}
119120
struct S : P {
120121
@differentiable(wrt: (x, y))
121-
func foo(_ x: Float, _ y: Double) -> Float {
122+
func foo(_ x: Tracked<Float>, _ y: Double) -> Tracked<Float> {
122123
return x
123124
}
124125
}
@@ -127,23 +128,24 @@ struct S : P {
127128

128129
public protocol Distribution {
129130
associatedtype Value
130-
func logProbability(of value: Value) -> Float
131+
func logProbability(of value: Value) -> Tracked<Float>
131132
}
132133

133134
public protocol DifferentiableDistribution: Differentiable, Distribution {
134135
@differentiable(wrt: self)
135-
func logProbability(of value: Value) -> Float
136+
func logProbability(of value: Value) -> Tracked<Float>
136137
}
137138

138139
struct Foo: DifferentiableDistribution {
139140
@differentiable(wrt: self)
140-
func logProbability(of value: Float) -> Float {
141+
func logProbability(of value: Tracked<Float>) -> Tracked<Float> {
141142
.zero
142143
}
143144
}
144145

145146
@differentiable
146-
func blah<T: DifferentiableDistribution>(_ x: T) -> Float where T.Value: AdditiveArithmetic {
147+
func blah<T: DifferentiableDistribution>(_ x: T) -> Tracked<Float>
148+
where T.Value: AdditiveArithmetic {
147149
x.logProbability(of: .zero)
148150
}
149151

@@ -152,29 +154,29 @@ public protocol DoubleDifferentiableDistribution: DifferentiableDistribution
152154
where Value: Differentiable {
153155
@differentiable(wrt: self)
154156
@differentiable(wrt: (self, value))
155-
func logProbability(of value: Value) -> Float
157+
func logProbability(of value: Value) -> Tracked<Float>
156158
}
157159

158160
@differentiable
159-
func blah2<T: DoubleDifferentiableDistribution>(_ x: T, _ value: T.Value) -> Float
161+
func blah2<T: DoubleDifferentiableDistribution>(_ x: T, _ value: T.Value) -> Tracked<Float>
160162
where T.Value: AdditiveArithmetic {
161163
x.logProbability(of: value)
162164
}
163165

164166
protocol DifferentiableFoo {
165167
associatedtype T: Differentiable
166168
@differentiable(wrt: x)
167-
func foo(_ x: T) -> Float
169+
func foo(_ x: T) -> Tracked<Float>
168170
}
169171

170172
protocol MoreDifferentiableFoo: Differentiable, DifferentiableFoo {
171173
@differentiable(wrt: (self, x))
172-
func foo(_ x: T) -> Float
174+
func foo(_ x: T) -> Tracked<Float>
173175
}
174176

175177
struct MoreDifferentiableFooStruct: MoreDifferentiableFoo {
176178
@differentiable(wrt: (self, x))
177-
func foo(_ x: Float) -> Float {
179+
func foo(_ x: Tracked<Float>) -> Tracked<Float> {
178180
x
179181
}
180182
}

test/AutoDiff/repeated_calls.swift

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,15 @@
22
// REQUIRES: executable_test
33

44
import StdlibUnittest
5+
import DifferentiationUnittest
56

67
var RepeatedCallsTests = TestSuite("RepeatedCalls")
78

8-
RepeatedCallsTests.test("Repeat") {
9-
func mul2(_ x: Float) -> Float {
9+
RepeatedCallsTests.testWithLeakChecking("Repeat") {
10+
func mul2(_ x: Tracked<Float>) -> Tracked<Float> {
1011
return 2 * x
1112
}
12-
func mul4(_ x: Float) -> Float {
13+
func mul4(_ x: Tracked<Float>) -> Tracked<Float> {
1314
return mul2(mul2(x))
1415
}
1516
expectEqual(4, gradient(at: 0, in: mul4))

test/AutoDiff/separate_tangent_type.swift

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,40 +7,41 @@ import Darwin.C
77
#else
88
import Glibc
99
#endif
10+
import DifferentiationUnittest
1011

1112
var SeparateTangentTypeTests = TestSuite("SeparateTangentType")
1213

1314
struct DifferentiableSubset : Differentiable {
1415
@differentiable(wrt: self)
15-
var w: Float
16+
var w: Tracked<Float>
1617
@differentiable(wrt: self)
17-
var b: Float
18+
var b: Tracked<Float>
1819
@noDerivative var flag: Bool
1920

2021
struct TangentVector : Differentiable, AdditiveArithmetic {
2122
typealias TangentVector = DifferentiableSubset.TangentVector
22-
var w: Float
23-
var b: Float
23+
var w: Tracked<Float>
24+
var b: Tracked<Float>
2425
}
2526
mutating func move(along v: TangentVector) {
2627
w.move(along: v.w)
2728
b.move(along: v.b)
2829
}
2930
}
3031

31-
SeparateTangentTypeTests.test("Trivial") {
32+
SeparateTangentTypeTests.testWithLeakChecking("Trivial") {
3233
let x = DifferentiableSubset(w: 0, b: 1, flag: false)
3334
let pb = pullback(at: x) { x in x }
3435
expectEqual(pb(DifferentiableSubset.TangentVector.zero), DifferentiableSubset.TangentVector.zero)
3536
}
3637

37-
SeparateTangentTypeTests.test("Initialization") {
38+
SeparateTangentTypeTests.testWithLeakChecking("Initialization") {
3839
let x = DifferentiableSubset(w: 0, b: 1, flag: false)
3940
let pb = pullback(at: x) { x in DifferentiableSubset(w: 1, b: 2, flag: true) }
4041
expectEqual(pb(DifferentiableSubset.TangentVector.zero), DifferentiableSubset.TangentVector.zero)
4142
}
4243

43-
SeparateTangentTypeTests.test("SomeArithmetics") {
44+
SeparateTangentTypeTests.testWithLeakChecking("SomeArithmetics") {
4445
let x = DifferentiableSubset(w: 0, b: 1, flag: false)
4546
let pb = pullback(at: x) { x in DifferentiableSubset(w: x.w * x.w, b: x.b * x.b, flag: true) }
4647
expectEqual(pb(DifferentiableSubset.TangentVector.zero), DifferentiableSubset.TangentVector.zero)

0 commit comments

Comments
 (0)