@@ -20,11 +20,22 @@ public func fabs<T: FloatingPoint>(_ x: T) -> T {
20
20
}
21
21
22
22
@_transparent
23
+ // SWIFT_ENABLE_TENSORFLOW
24
+ @differentiable (
25
+ vjp: _vjpSqrt
26
+ where T : Differentiable & FloatingPoint, T == T . TangentVector
27
+ )
23
28
public func sqrt< T: FloatingPoint > ( _ x: T ) -> T {
24
29
return x. squareRoot ( )
25
30
}
26
31
27
32
@_transparent
33
+ // SWIFT_ENABLE_TENSORFLOW
34
+ @differentiable (
35
+ wrt: ( x, y, z) ,
36
+ vjp: _vjpFma
37
+ where T : Differentiable & FloatingPoint, T == T . TangentVector
38
+ )
28
39
public func fma< T: FloatingPoint > ( _ x: T , _ y: T , _ z: T ) -> T {
29
40
return z. addingProduct ( x, y)
30
41
}
@@ -82,6 +93,24 @@ public func frexp<T: BinaryFloatingPoint>(_ x: T) -> (T, Int) {
82
93
return ( x. significand / 2 , Int ( x. exponent + 1 ) )
83
94
}
84
95
96
+ // SWIFT_ENABLE_TENSORFLOW
97
+ @usableFromInline
98
+ func _vjpSqrt< T: FloatingPoint & Differentiable > (
99
+ _ x: T
100
+ ) -> ( T , ( T ) -> T ) where T == T . TangentVector {
101
+ let value = x. squareRoot ( )
102
+ return ( value, { v in v / ( 2 * value) } )
103
+ }
104
+
105
+ @usableFromInline
106
+ func _vjpFma< T: FloatingPoint & Differentiable > (
107
+ _ x: T ,
108
+ _ y: T ,
109
+ _ z: T
110
+ ) -> ( T , ( T ) -> ( T , T , T ) ) where T == T . TangentVector {
111
+ return ( fma ( x, y, z) , { v in ( v * y, v * x, v) } )
112
+ }
113
+
85
114
% for T in [ 'Float', 'Double'] :
86
115
@available ( swift, deprecated: 4.2 , renamed: " scalbn " )
87
116
@_transparent
@@ -102,11 +131,27 @@ func _vjpExp(_ x: ${T}) -> (${T}, (${T}) -> ${T}) {
102
131
return ( value, { v in value * v } )
103
132
}
104
133
134
+ @usableFromInline
135
+ func _vjpExp2 ( _ x: ${ T} ) - > ( ${ T} , ( ${ T} ) - > ${ T} ) {
136
+ let value = exp2 ( x)
137
+ return ( value, { v in v * ${ T} ( M_LN2) * value } )
138
+ }
139
+
105
140
@usableFromInline
106
141
func _vjpLog( _ x: ${ T} ) -> ( ${ T} , ( ${ T} ) -> ${ T} ) {
107
142
return ( log( x ) , { v in v / x } )
108
143
}
109
144
145
+ @usableFromInline
146
+ func _vjpLog10( _ x: ${ T} ) -> ( ${ T} , ( ${ T} ) -> ${ T} ) {
147
+ return ( log10( x ) , { v in v * ${ T} ( M_LOG10E ) / x } )
148
+ }
149
+
150
+ @usableFromInline
151
+ func _vjpLog2( _ x: ${ T} ) -> ( ${ T} , ( ${ T} ) -> ${ T} ) {
152
+ return ( log2( x ) , { v in v / ( ${ T} ( M_LN2 ) * x) } )
153
+ }
154
+
110
155
@usableFromInline
111
156
func _vjpSin( _ x: ${ T} ) -> ( ${ T} , ( ${ T} ) -> ${ T} ) {
112
157
return ( sin( x ) , { v in v * cos( x ) } )
@@ -122,6 +167,72 @@ func _vjpTan(_ x: ${T}) -> (${T}, (${T}) -> ${T}) {
122
167
let value = tan( x )
123
168
return ( value , { v in v * ( 1 + value * value) } )
124
169
}
170
+
171
+ @usableFromInline
172
+ func _vjpAsin( _ x: ${ T} ) -> ( ${ T} , ( ${ T} ) -> ${ T} ) {
173
+ return ( asin( x ) , { v in v / sqrt( 1 - x * x) } )
174
+ }
175
+
176
+ @usableFromInline
177
+ func _vjpAcos( _ x: ${ T} ) -> ( ${ T} , ( ${ T} ) -> ${ T} ) {
178
+ return ( acos( x ) , { v in - v / sqrt( 1 - x * x) } )
179
+ }
180
+
181
+ @usableFromInline
182
+ func _vjpAtan( _ x: ${ T} ) -> ( ${ T} , ( ${ T} ) -> ${ T} ) {
183
+ return ( atan( x ) , { v in v / ( 1 + x * x) } )
184
+ }
185
+
186
+ @usableFromInline
187
+ func _vjpSinh( _ x: ${ T} ) -> ( ${ T} , ( ${ T} ) -> ${ T} ) {
188
+ return ( sinh( x ) , { v in v * cosh( x ) } )
189
+ }
190
+
191
+ @usableFromInline
192
+ func _vjpCosh( _ x: ${ T} ) -> ( ${ T} , ( ${ T} ) -> ${ T} ) {
193
+ return ( cosh( x ) , { v in v * sinh( x ) } )
194
+ }
195
+
196
+ @usableFromInline
197
+ func _vjpTanh( _ x: ${ T} ) -> ( ${ T} , ( ${ T} ) -> ${ T} ) {
198
+ let value = tanh( x )
199
+ return ( value , { v in v * ( 1 - value * value) } )
200
+ }
201
+
202
+ @usableFromInline
203
+ func _vjpAsinh( _ x: ${ T} ) -> ( ${ T} , ( ${ T} ) -> ${ T} ) {
204
+ return ( asinh( x ) , { v in v / sqrt( 1 + x * x) } )
205
+ }
206
+
207
+ @usableFromInline
208
+ func _vjpAcosh( _ x: ${ T} ) -> ( ${ T} , ( ${ T} ) -> ${ T} ) {
209
+ return ( acosh( x ) , { v in v / sqrt( x * x - 1 ) } )
210
+ }
211
+
212
+ @usableFromInline
213
+ func _vjpAtanh( _ x: ${ T} ) -> ( ${ T} , ( ${ T} ) -> ${ T} ) {
214
+ return ( atanh( x ) , { v in v / ( 1 - x * x) } )
215
+ }
216
+
217
+ @usableFromInline
218
+ func _vjpExpm1( _ x: ${ T} ) -> ( ${ T} , ( ${ T} ) -> ${ T} ) {
219
+ return ( expm1( x ) , { v in exp( x ) * v } )
220
+ }
221
+
222
+ @usableFromInline
223
+ func _vjpLog1p( _ x: ${ T} ) -> ( ${ T} , ( ${ T} ) -> ${ T} ) {
224
+ return ( log1p( x ) , { v in v / ( x + 1 ) } )
225
+ }
226
+
227
+ @usableFromInline
228
+ func _vjpErf( _ x: ${ T} ) -> ( ${ T} , ( ${ T} ) -> ${ T} ) {
229
+ return ( erf( x ) , { v in v * ${ T} ( M_2_SQRTPI ) * exp( - x * x) } )
230
+ }
231
+
232
+ @usableFromInline
233
+ func _vjpErfc( _ x: ${ T} ) -> ( ${ T} , ( ${ T} ) -> ${ T} ) {
234
+ return ( erfc( x ) , { v in v * - ${ T} ( M_2_SQRTPI ) * exp( - x * x) } )
235
+ }
125
236
% if T == 'Float80 ':
126
237
#endif
127
238
% end
@@ -201,7 +312,14 @@ UnaryIntrinsicFunctions = [
201
312
]
202
313
203
314
# SWIFT_ENABLE_TENSORFLOW
204
- HasVJP = [ " exp " , " log " , " tan " , " cos " , " sin " ]
315
+ HasVJP = [
316
+ 'acos', 'asin', 'atan', 'tan',
317
+ 'acosh', 'asinh', 'atanh', 'cosh', 'sinh', 'tanh',
318
+ 'expm1 ',
319
+ 'log1 p',
320
+ 'erf', 'erfc',
321
+ 'cos', 'sin', 'exp', 'exp2 ', 'log', 'log10 ', 'log2 '
322
+ ]
205
323
206
324
def AllFloatTypes( ) :
207
325
for bits in allFloatBits:
0 commit comments