1
1
// RUN: %target-run-simple-swift
2
2
3
3
import StdlibUnittest
4
+ import DifferentiationUnittest
4
5
5
6
var ProtocolRequirementAutodiffTests = TestSuite ( " ProtocolRequirementAutodiff " )
6
7
7
8
// MARK: - Func requirements.
8
9
9
10
protocol DiffReq : Differentiable {
10
11
@differentiable ( wrt: ( self , x) )
11
- func f( _ x: Float ) -> Float
12
+ func f( _ x: Tracked < Float > ) -> Tracked < Float >
12
13
}
13
14
14
15
extension DiffReq where TangentVector : AdditiveArithmetic {
15
16
@inline ( never) // Prevent specialization, to test all witness code.
16
- func gradF( at x: Float ) -> ( Self . TangentVector , Float ) {
17
+ func gradF( at x: Tracked < Float > ) -> ( Self . TangentVector , Tracked < Float > ) {
17
18
return ( valueWithPullback ( at: x) { s, x in s. f ( x) } ) . 1 ( 1 )
18
19
}
19
20
}
@@ -22,27 +23,27 @@ struct Quadratic : DiffReq, AdditiveArithmetic {
22
23
typealias TangentVector = Quadratic
23
24
24
25
@differentiable
25
- let a : Float
26
+ let a : Tracked < Float >
26
27
27
28
@differentiable
28
- let b : Float
29
+ let b : Tracked < Float >
29
30
30
31
@differentiable
31
- let c : Float
32
+ let c : Tracked < Float >
32
33
33
- init ( _ a: Float , _ b: Float , _ c: Float ) {
34
+ init ( _ a: Tracked < Float > , _ b: Tracked < Float > , _ c: Tracked < Float > ) {
34
35
self . a = a
35
36
self . b = b
36
37
self . c = c
37
38
}
38
39
39
40
@differentiable ( wrt: ( self , x) )
40
- func f( _ x: Float ) -> Float {
41
+ func f( _ x: Tracked < Float > ) -> Tracked < Float > {
41
42
return a * x * x + b * x + c
42
43
}
43
44
}
44
45
45
- ProtocolRequirementAutodiffTests . test ( " func " ) {
46
+ ProtocolRequirementAutodiffTests . testWithLeakChecking ( " func " ) {
46
47
expectEqual ( ( Quadratic ( 0 , 0 , 1 ) , 12 ) , Quadratic ( 11 , 12 , 13 ) . gradF ( at: 0 ) )
47
48
expectEqual ( ( Quadratic ( 1 , 1 , 1 ) , 2 * 11 + 12 ) ,
48
49
Quadratic ( 11 , 12 , 13 ) . gradF ( at: 1 ) )
@@ -54,48 +55,48 @@ ProtocolRequirementAutodiffTests.test("func") {
54
55
55
56
protocol FunctionsOfX : Differentiable {
56
57
@differentiable
57
- init ( x: Float )
58
+ init ( x: Tracked < Float > )
58
59
59
60
@differentiable
60
- var x : Float { get }
61
+ var x : Tracked < Float > { get }
61
62
62
63
@differentiable
63
- var y : Float { get }
64
+ var y : Tracked < Float > { get }
64
65
65
66
@differentiable
66
- var z : Float { get }
67
+ var z : Tracked < Float > { get }
67
68
68
69
@differentiable
69
- subscript( ) -> Float { get }
70
+ subscript( ) -> Tracked < Float > { get }
70
71
}
71
72
72
73
struct TestFunctionsOfX : FunctionsOfX {
73
74
@differentiable
74
- init ( x: Float ) {
75
+ init ( x: Tracked < Float > ) {
75
76
self . x = x
76
77
self . y = x * x
77
78
}
78
79
79
80
/// x = x
80
- var x : Float
81
+ var x : Tracked < Float >
81
82
82
83
/// y = x * x
83
- var y : Float
84
+ var y : Tracked < Float >
84
85
85
86
/// z = x * x + x
86
- var z : Float {
87
+ var z : Tracked < Float > {
87
88
return y + x
88
89
}
89
90
90
91
@differentiable
91
- subscript( ) -> Float {
92
+ subscript( ) -> Tracked < Float > {
92
93
return z
93
94
}
94
95
}
95
96
96
97
@inline ( never) // Prevent specialization, to test all witness code.
97
- func derivatives< F: FunctionsOfX > ( at x: Float , in: F . Type )
98
- -> ( Float , Float , Float , Float )
98
+ func derivatives< F: FunctionsOfX > ( at x: Tracked < Float > , in: F . Type )
99
+ -> ( Tracked < Float > , Tracked < Float > , Tracked < Float > , Tracked < Float > )
99
100
{
100
101
let dxdx = gradient ( at: x) { x in F ( x: x) . x }
101
102
let dydx = gradient ( at: x) { x in F ( x: x) . y }
@@ -104,7 +105,7 @@ func derivatives<F: FunctionsOfX>(at x: Float, in: F.Type)
104
105
return ( dxdx, dydx, dzdx, dsubscriptdx)
105
106
}
106
107
107
- ProtocolRequirementAutodiffTests . test ( " constructor, accessor, subscript " ) {
108
+ ProtocolRequirementAutodiffTests . testWithLeakChecking ( " constructor, accessor, subscript " ) {
108
109
expectEqual (
109
110
( 1.0 , 4.0 , 5.0 , 5.0 ) ,
110
111
derivatives ( at: 2.0 , in: TestFunctionsOfX . self) )
@@ -114,11 +115,11 @@ ProtocolRequirementAutodiffTests.test("constructor, accessor, subscript") {
114
115
115
116
protocol P : Differentiable {
116
117
@differentiable ( wrt: ( x, y) )
117
- func foo( _ x: Float , _ y: Double ) -> Float
118
+ func foo( _ x: Tracked < Float > , _ y: Double ) -> Tracked < Float >
118
119
}
119
120
struct S : P {
120
121
@differentiable ( wrt: ( x, y) )
121
- func foo( _ x: Float , _ y: Double ) -> Float {
122
+ func foo( _ x: Tracked < Float > , _ y: Double ) -> Tracked < Float > {
122
123
return x
123
124
}
124
125
}
@@ -127,23 +128,24 @@ struct S : P {
127
128
128
129
public protocol Distribution {
129
130
associatedtype Value
130
- func logProbability( of value: Value ) -> Float
131
+ func logProbability( of value: Value ) -> Tracked < Float >
131
132
}
132
133
133
134
public protocol DifferentiableDistribution : Differentiable , Distribution {
134
135
@differentiable ( wrt: self )
135
- func logProbability( of value: Value ) -> Float
136
+ func logProbability( of value: Value ) -> Tracked < Float >
136
137
}
137
138
138
139
struct Foo : DifferentiableDistribution {
139
140
@differentiable ( wrt: self )
140
- func logProbability( of value: Float ) -> Float {
141
+ func logProbability( of value: Tracked < Float > ) -> Tracked < Float > {
141
142
. zero
142
143
}
143
144
}
144
145
145
146
@differentiable
146
- func blah< T: DifferentiableDistribution > ( _ x: T ) -> Float where T. Value: AdditiveArithmetic {
147
+ func blah< T: DifferentiableDistribution > ( _ x: T ) -> Tracked < Float >
148
+ where T. Value: AdditiveArithmetic {
147
149
x. logProbability ( of: . zero)
148
150
}
149
151
@@ -152,29 +154,29 @@ public protocol DoubleDifferentiableDistribution: DifferentiableDistribution
152
154
where Value: Differentiable {
153
155
@differentiable ( wrt: self )
154
156
@differentiable ( wrt: ( self , value) )
155
- func logProbability( of value: Value ) -> Float
157
+ func logProbability( of value: Value ) -> Tracked < Float >
156
158
}
157
159
158
160
@differentiable
159
- func blah2< T: DoubleDifferentiableDistribution > ( _ x: T , _ value: T . Value ) -> Float
161
+ func blah2< T: DoubleDifferentiableDistribution > ( _ x: T , _ value: T . Value ) -> Tracked < Float >
160
162
where T. Value: AdditiveArithmetic {
161
163
x. logProbability ( of: value)
162
164
}
163
165
164
166
protocol DifferentiableFoo {
165
167
associatedtype T : Differentiable
166
168
@differentiable ( wrt: x)
167
- func foo( _ x: T ) -> Float
169
+ func foo( _ x: T ) -> Tracked < Float >
168
170
}
169
171
170
172
protocol MoreDifferentiableFoo : Differentiable , DifferentiableFoo {
171
173
@differentiable ( wrt: ( self , x) )
172
- func foo( _ x: T ) -> Float
174
+ func foo( _ x: T ) -> Tracked < Float >
173
175
}
174
176
175
177
struct MoreDifferentiableFooStruct : MoreDifferentiableFoo {
176
178
@differentiable ( wrt: ( self , x) )
177
- func foo( _ x: Float ) -> Float {
179
+ func foo( _ x: Tracked < Float > ) -> Tracked < Float > {
178
180
x
179
181
}
180
182
}
0 commit comments