11
11
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
12
12
//
13
13
//===----------------------------------------------------------------------===//
14
+ // This file defines derivatives for tgmath functions.
15
+ //===----------------------------------------------------------------------===//
14
16
15
17
@usableFromInline
16
- @derivative ( of: sqrt)
17
- func _vjpSqrt< T: FloatingPoint & Differentiable > (
18
- _ x: T
19
- ) -> ( value: T , pullback: ( T ) -> T ) where T == T . TangentVector {
20
- let value = sqrt ( x)
21
- return ( value, { v in v / ( 2 * value) } )
18
+ @derivative ( of: fma)
19
+ func _jvpFma< T: FloatingPoint & Differentiable > (
20
+ _ x: T ,
21
+ _ y: T ,
22
+ _ z: T
23
+ ) -> ( value: T , differential: ( T , T , T ) -> T ) where T == T . TangentVector {
24
+ return ( fma ( x, y, z) , { ( dx, dy, dz) in dx * y + dy * x + dz } )
22
25
}
23
26
24
27
@usableFromInline
@@ -31,6 +34,18 @@ func _vjpFma<T: FloatingPoint & Differentiable> (
31
34
return ( fma ( x, y, z) , { v in ( v * y, v * x, v) } )
32
35
}
33
36
37
+ @usableFromInline
38
+ @derivative ( of: remainder)
39
+ func _jvpRemainder< T: FloatingPoint & Differentiable > (
40
+ _ x: T ,
41
+ _ y: T
42
+ ) -> ( value: T , differential: ( T , T ) -> T ) where T == T . TangentVector {
43
+ fatalError ( """
44
+ Unimplemented JVP for 'remainder(_:)'. \
45
+ https://bugs.swift.org/browse/TF-1108 tracks this issue
46
+ """ )
47
+ }
48
+
34
49
@usableFromInline
35
50
@derivative ( of: remainder)
36
51
func _vjpRemainder< T: FloatingPoint & Differentiable > (
@@ -40,6 +55,18 @@ func _vjpRemainder<T: FloatingPoint & Differentiable> (
40
55
return ( remainder ( x, y) , { v in ( v, - v * ( ( x / y) . rounded ( . toNearestOrEven) ) ) } )
41
56
}
42
57
58
+ @usableFromInline
59
+ @derivative ( of: fmod)
60
+ func _jvpFmod< T: FloatingPoint & Differentiable > (
61
+ _ x: T ,
62
+ _ y: T
63
+ ) -> ( value: T , differential: ( T , T ) -> T ) where T == T . TangentVector {
64
+ fatalError ( """
65
+ Unimplemented JVP for 'fmod(_:)'. \
66
+ https://bugs.swift.org/browse/TF-1108 tracks this issue
67
+ """ )
68
+ }
69
+
43
70
@usableFromInline
44
71
@derivative ( of: fmod)
45
72
func _vjpFmod< T: FloatingPoint & Differentiable > (
@@ -49,173 +76,188 @@ func _vjpFmod<T: FloatingPoint & Differentiable> (
49
76
return ( fmod ( x, y) , { v in ( v, - v * ( ( x / y) . rounded ( . towardZero) ) ) } )
50
77
}
51
78
79
+ % for derivative_kind in [ 'jvp', 'vjp'] :
80
+ % linear_map_kind = 'differential' if derivative_kind == 'jvp' else 'pullback'
81
+ @usableFromInline
82
+ @derivative ( of: sqrt)
83
+ func _${ derivative_kind} Sqrt< T: FloatingPoint & Differentiable> (
84
+ _ x: T
85
+ ) - > ( value: T, ${ linear_map_kind} : ( T) - > T) where T == T . TangentVector {
86
+ let value = sqrt ( x)
87
+ return ( value, { v in v / ( 2 * value) } )
88
+ }
89
+
52
90
@usableFromInline
53
91
@derivative ( of: ceil)
54
- func _vjpCeil < T: FloatingPoint & Differentiable > (
92
+ func _$ { derivative_kind } Ceil < T: FloatingPoint & Differentiable> (
55
93
_ x: T
56
- ) -> ( value: T , pullback : ( T ) -> T ) where T == T . TangentVector {
94
+ ) - > ( value: T, $ { linear_map_kind } : ( T) - > T) where T == T . TangentVector {
57
95
return ( ceil ( x) , { v in 0 } )
58
96
}
59
97
60
98
@usableFromInline
61
99
@derivative ( of: floor)
62
- func _vjpFloor < T: FloatingPoint & Differentiable > (
100
+ func _$ { derivative_kind } Floor < T: FloatingPoint & Differentiable> (
63
101
_ x: T
64
- ) -> ( value: T , pullback : ( T ) -> T ) where T == T . TangentVector {
102
+ ) - > ( value: T, $ { linear_map_kind } : ( T) - > T) where T == T . TangentVector {
65
103
return ( floor ( x) , { v in 0 } )
66
104
}
67
105
68
106
@usableFromInline
69
107
@derivative ( of: round)
70
- func _vjpRound < T: FloatingPoint & Differentiable > (
108
+ func _$ { derivative_kind } Round < T: FloatingPoint & Differentiable> (
71
109
_ x: T
72
- ) -> ( value: T , pullback : ( T ) -> T ) where T == T . TangentVector {
110
+ ) - > ( value: T, $ { linear_map_kind } : ( T) - > T) where T == T . TangentVector {
73
111
return ( round ( x) , { v in 0 } )
74
112
}
75
113
76
114
@usableFromInline
77
115
@derivative ( of: trunc)
78
- func _vjpTrunc < T: FloatingPoint & Differentiable > (
116
+ func _$ { derivative_kind } Trunc < T: FloatingPoint & Differentiable> (
79
117
_ x: T
80
- ) -> ( value: T , pullback : ( T ) -> T ) where T == T . TangentVector {
118
+ ) - > ( value: T, $ { linear_map_kind } : ( T) - > T) where T == T . TangentVector {
81
119
return ( trunc ( x) , { v in 0 } )
82
120
}
121
+ % end # for derivative_kind in [ 'jvp', 'vjp'] :
83
122
84
- % for T in [ 'Float', 'Double', 'Float80 '] :
85
- % if T == 'Float80 ':
123
+ % for derivative_kind in [ 'jvp', 'vjp'] :
124
+ % linear_map_kind = 'differential' if derivative_kind == 'jvp' else 'pullback'
125
+ % for T in [ 'Float', 'Double', 'Float80 '] :
126
+ % if T == 'Float80 ':
86
127
#if !(os(Windows) || os(Android)) && (arch(i386) || arch(x86_64))
87
- % end
128
+ % end
88
129
@inlinable
89
130
@derivative ( of: exp)
90
- func _vjpExp ( _ x: ${ T} ) - > ( value: ${ T} , pullback : ( ${ T} ) - > ${ T} ) {
131
+ func _$ { derivative_kind } Exp ( _ x: ${ T} ) - > ( value: ${ T} , $ { linear_map_kind } : ( ${ T} ) - > ${ T} ) {
91
132
let value = exp ( x)
92
133
return ( value, { v in value * v } )
93
134
}
94
135
95
136
@inlinable
96
137
@derivative ( of: exp2)
97
- func _vjpExp 2 ( _ x: ${ T} ) - > ( value: ${ T} , pullback : ( ${ T} ) - > ${ T} ) {
138
+ func _$ { derivative_kind } Exp2 ( _ x: ${ T} ) - > ( value: ${ T} , $ { linear_map_kind } : ( ${ T} ) - > ${ T} ) {
98
139
let value = exp2 ( x)
99
140
return ( value, { v in v * ${ T} ( M_LN2) * value } )
100
141
}
101
142
102
143
@inlinable
103
144
@derivative( of: log)
104
- func _vjpLog ( _ x: ${ T} ) -> ( value: ${ T} , pullback : ( ${ T} ) -> ${ T} ) {
145
+ func _$ { derivative_kind } Log ( _ x: ${ T} ) - > ( value: ${ T} , $ { linear_map_kind } : ( ${ T} ) - > ${ T} ) {
105
146
return ( log ( x) , { v in v / x } )
106
147
}
107
148
108
149
@inlinable
109
150
@derivative( of: log10)
110
- func _vjpLog10 ( _ x: ${ T} ) -> ( value: ${ T} , pullback : ( ${ T} ) -> ${ T} ) {
151
+ func _$ { derivative_kind } Log10 ( _ x: ${ T} ) - > ( value: ${ T} , $ { linear_map_kind } : ( ${ T} ) - > ${ T} ) {
111
152
return ( log10 ( x) , { v in v * ${ T} ( M_LOG10E) / x } )
112
153
}
113
154
114
155
@inlinable
115
156
@derivative( of: log2)
116
- func _vjpLog 2 ( _ x: ${ T} ) -> ( value: ${ T} , pullback : ( ${ T} ) -> ${ T} ) {
157
+ func _$ { derivative_kind } Log2 ( _ x: ${ T} ) - > ( value: ${ T} , $ { linear_map_kind } : ( ${ T} ) - > ${ T} ) {
117
158
return ( log2 ( x) , { v in v / ( ${ T} ( M_LN2) * x) } )
118
159
}
119
160
120
161
@inlinable
121
162
@derivative( of: sin)
122
- func _vjpSin ( _ x: ${ T} ) -> ( value: ${ T} , pullback : ( ${ T} ) -> ${ T} ) {
163
+ func _$ { derivative_kind } Sin ( _ x: ${ T} ) - > ( value: ${ T} , $ { linear_map_kind } : ( ${ T} ) - > ${ T} ) {
123
164
return ( sin ( x) , { v in v * cos( x) } )
124
165
}
125
166
126
167
@inlinable
127
168
@derivative( of: cos)
128
- func _vjpCos ( _ x: ${ T} ) -> ( value: ${ T} , pullback : ( ${ T} ) -> ${ T} ) {
169
+ func _$ { derivative_kind } Cos ( _ x: ${ T} ) - > ( value: ${ T} , $ { linear_map_kind } : ( ${ T} ) - > ${ T} ) {
129
170
return ( cos ( x) , { v in - v * sin( x) } )
130
171
}
131
172
132
173
@inlinable
133
174
@derivative( of: tan)
134
- func _vjpTan ( _ x: ${ T} ) -> ( value: ${ T} , pullback : ( ${ T} ) -> ${ T} ) {
175
+ func _$ { derivative_kind } Tan ( _ x: ${ T} ) - > ( value: ${ T} , $ { linear_map_kind } : ( ${ T} ) - > ${ T} ) {
135
176
let value = tan ( x)
136
177
return ( value, { v in v * ( 1 + value * value) } )
137
178
}
138
179
139
180
@inlinable
140
181
@derivative( of: asin)
141
- func _vjpAsin ( _ x: ${ T} ) -> ( value: ${ T} , pullback : ( ${ T} ) -> ${ T} ) {
182
+ func _$ { derivative_kind } Asin ( _ x: ${ T} ) - > ( value: ${ T} , $ { linear_map_kind } : ( ${ T} ) - > ${ T} ) {
142
183
return ( asin ( x) , { v in v / sqrt( 1 - x * x) } )
143
184
}
144
185
145
186
@inlinable
146
187
@derivative( of: acos)
147
- func _vjpAcos ( _ x: ${ T} ) -> ( value: ${ T} , pullback : ( ${ T} ) -> ${ T} ) {
188
+ func _$ { derivative_kind } Acos ( _ x: ${ T} ) - > ( value: ${ T} , $ { linear_map_kind } : ( ${ T} ) - > ${ T} ) {
148
189
return ( acos ( x) , { v in - v / sqrt( 1 - x * x) } )
149
190
}
150
191
151
192
@inlinable
152
193
@derivative( of: atan)
153
- func _vjpAtan ( _ x: ${ T} ) -> ( value: ${ T} , pullback : ( ${ T} ) -> ${ T} ) {
194
+ func _$ { derivative_kind } Atan ( _ x: ${ T} ) - > ( value: ${ T} , $ { linear_map_kind } : ( ${ T} ) - > ${ T} ) {
154
195
return ( atan ( x) , { v in v / ( 1 + x * x) } )
155
196
}
156
197
157
198
@inlinable
158
199
@derivative( of: sinh)
159
- func _vjpSinh ( _ x: ${ T} ) -> ( value: ${ T} , pullback : ( ${ T} ) -> ${ T} ) {
200
+ func _$ { derivative_kind } Sinh ( _ x: ${ T} ) - > ( value: ${ T} , $ { linear_map_kind } : ( ${ T} ) - > ${ T} ) {
160
201
return ( sinh ( x) , { v in v * cosh( x) } )
161
202
}
162
203
163
204
@inlinable
164
205
@derivative( of: cosh)
165
- func _vjpCosh ( _ x: ${ T} ) -> ( value: ${ T} , pullback : ( ${ T} ) -> ${ T} ) {
206
+ func _$ { derivative_kind } Cosh ( _ x: ${ T} ) - > ( value: ${ T} , $ { linear_map_kind } : ( ${ T} ) - > ${ T} ) {
166
207
return ( cosh ( x) , { v in v * sinh( x) } )
167
208
}
168
209
169
210
@inlinable
170
211
@derivative( of: tanh)
171
- func _vjpTanh ( _ x: ${ T} ) -> ( value: ${ T} , pullback : ( ${ T} ) -> ${ T} ) {
212
+ func _$ { derivative_kind } Tanh ( _ x: ${ T} ) - > ( value: ${ T} , $ { linear_map_kind } : ( ${ T} ) - > ${ T} ) {
172
213
let value = tanh ( x)
173
214
return ( value, { v in v * ( 1 - value * value) } )
174
215
}
175
216
176
217
@inlinable
177
218
@derivative( of: asinh)
178
- func _vjpAsinh ( _ x: ${ T} ) -> ( value: ${ T} , pullback : ( ${ T} ) -> ${ T} ) {
219
+ func _$ { derivative_kind } Asinh ( _ x: ${ T} ) - > ( value: ${ T} , $ { linear_map_kind } : ( ${ T} ) - > ${ T} ) {
179
220
return ( asinh ( x) , { v in v / sqrt( 1 + x * x) } )
180
221
}
181
222
182
223
@inlinable
183
224
@derivative( of: acosh)
184
- func _vjpAcosh ( _ x: ${ T} ) -> ( value: ${ T} , pullback : ( ${ T} ) -> ${ T} ) {
225
+ func _$ { derivative_kind } Acosh ( _ x: ${ T} ) - > ( value: ${ T} , $ { linear_map_kind } : ( ${ T} ) - > ${ T} ) {
185
226
return ( acosh ( x) , { v in v / sqrt( x * x - 1 ) } )
186
227
}
187
228
188
229
@inlinable
189
230
@derivative( of: atanh)
190
- func _vjpAtanh ( _ x: ${ T} ) -> ( value: ${ T} , pullback : ( ${ T} ) -> ${ T} ) {
231
+ func _$ { derivative_kind } Atanh ( _ x: ${ T} ) - > ( value: ${ T} , $ { linear_map_kind } : ( ${ T} ) - > ${ T} ) {
191
232
return ( atanh ( x) , { v in v / ( 1 - x * x) } )
192
233
}
193
234
194
235
@inlinable
195
236
@derivative( of: expm1)
196
- func _vjpExpm 1 ( _ x: ${ T} ) -> ( value: ${ T} , pullback : ( ${ T} ) -> ${ T} ) {
237
+ func _$ { derivative_kind } Expm1 ( _ x: ${ T} ) - > ( value: ${ T} , $ { linear_map_kind } : ( ${ T} ) - > ${ T} ) {
197
238
return ( expm1 ( x) , { v in exp ( x) * v } )
198
239
}
199
240
200
241
@inlinable
201
242
@derivative( of: log1p)
202
- func _vjpLog 1 p ( _ x: ${ T} ) -> ( value: ${ T} , pullback : ( ${ T} ) -> ${ T} ) {
243
+ func _$ { derivative_kind } Log1p ( _ x: ${ T} ) - > ( value: ${ T} , $ { linear_map_kind } : ( ${ T} ) - > ${ T} ) {
203
244
return ( log1p ( x) , { v in v / ( x + 1 ) } )
204
245
}
205
246
206
247
@inlinable
207
248
@derivative( of: erf)
208
- func _vjpErf ( _ x: ${ T} ) -> ( value: ${ T} , pullback : ( ${ T} ) -> ${ T} ) {
249
+ func _$ { derivative_kind } Erf ( _ x: ${ T} ) - > ( value: ${ T} , $ { linear_map_kind } : ( ${ T} ) - > ${ T} ) {
209
250
return ( erf ( x) , { v in v * ${ T} ( M_2_SQRTPI) * exp( - x * x) } )
210
251
}
211
252
212
253
@inlinable
213
254
@derivative( of: erfc)
214
- func _vjpErfc ( _ x: ${ T} ) -> ( value: ${ T} , pullback : ( ${ T} ) -> ${ T} ) {
255
+ func _$ { derivative_kind } Erfc ( _ x: ${ T} ) - > ( value: ${ T} , $ { linear_map_kind } : ( ${ T} ) - > ${ T} ) {
215
256
return ( erfc ( x) , { v in v * - ${ T} ( M_2_SQRTPI) * exp( - x * x) } )
216
257
}
217
258
218
- % if T == 'Float80 ':
259
+ % if T == 'Float80 ':
219
260
#endif
220
- % end
221
- % end
261
+ % end # if T == 'Float80 ':
262
+ % end # for T in [ 'Float', 'Double', 'Float80 '] :
263
+ % end # for derivative_kind in [ 'jvp', 'vjp'] :
0 commit comments