@@ -102,11 +102,27 @@ func _vjpExp(_ x: ${T}) -> (${T}, (${T}) -> ${T}) {
102
102
return ( value, { v in value * v } )
103
103
}
104
104
105
+ @usableFromInline
106
+ func _vjpExp2 ( _ x: ${ T} ) - > ( ${ T} , ( ${ T} ) - > ${ T} ) {
107
+ let value = exp2 ( x)
108
+ return ( value, { v in v * ${ T} ( M_LN2) * value } )
109
+ }
110
+
105
111
@usableFromInline
106
112
func _vjpLog( _ x: ${ T} ) -> ( ${ T} , ( ${ T} ) -> ${ T} ) {
107
113
return ( log( x ) , { v in v / x } )
108
114
}
109
115
116
+ @usableFromInline
117
+ func _vjpLog10( _ x: ${ T} ) -> ( ${ T} , ( ${ T} ) -> ${ T} ) {
118
+ return ( log10( x ) , { v in v * ${ T} ( M_LOG10E ) / x} )
119
+ }
120
+
121
+ @usableFromInline
122
+ func _vjpLog2( _ x: ${ T} ) -> ( ${ T} , ( ${ T} ) -> ${ T} ) {
123
+ return ( log2( x ) , { v in v / ( ${ T} ( M_LN2 ) * x) } )
124
+ }
125
+
110
126
@usableFromInline
111
127
func _vjpSin( _ x: ${ T} ) -> ( ${ T} , ( ${ T} ) -> ${ T} ) {
112
128
return ( sin( x ) , { v in v * cos( x ) } )
@@ -122,6 +138,72 @@ func _vjpTan(_ x: ${T}) -> (${T}, (${T}) -> ${T}) {
122
138
let value = tan( x )
123
139
return ( value , { v in v * ( 1 + value * value) } )
124
140
}
141
+
142
+ @usableFromInline
143
+ func _vjpAsin( _ x: ${ T} ) -> ( ${ T} , ( ${ T} ) -> ${ T} ) {
144
+ return ( asin( x ) , { v in v / sqrt( 1 - x * x) } )
145
+ }
146
+
147
+ @usableFromInline
148
+ func _vjpAcos( _ x: ${ T} ) -> ( ${ T} , ( ${ T} ) -> ${ T} ) {
149
+ return ( acos( x ) , { v in - v / sqrt( 1 - x * x) } )
150
+ }
151
+
152
+ @usableFromInline
153
+ func _vjpAtan( _ x: ${ T} ) -> ( ${ T} , ( ${ T} ) -> ${ T} ) {
154
+ return ( atan( x ) , { v in v / ( 1 + x * x) } )
155
+ }
156
+
157
+ @usableFromInline
158
+ func _vjpSinh( _ x: ${ T} ) -> ( ${ T} , ( ${ T} ) -> ${ T} ) {
159
+ return ( sinh( x ) , { v in v * cosh( x ) } )
160
+ }
161
+
162
+ @usableFromInline
163
+ func _vjpCosh( _ x: ${ T} ) -> ( ${ T} , ( ${ T} ) -> ${ T} ) {
164
+ return ( cosh( x ) , { v in v * sinh( x ) } )
165
+ }
166
+
167
+ @usableFromInline
168
+ func _vjpTanh( _ x: ${ T} ) -> ( ${ T} , ( ${ T} ) -> ${ T} ) {
169
+ let value = tanh( x )
170
+ return ( value , { v in v * ( 1 - value * value) } )
171
+ }
172
+
173
+ @usableFromInline
174
+ func _vjpAsinh( _ x: ${ T} ) -> ( ${ T} , ( ${ T} ) -> ${ T} ) {
175
+ return ( asinh( x ) , { v in v / sqrt( 1 + x * x) } )
176
+ }
177
+
178
+ @usableFromInline
179
+ func _vjpAcosh( _ x: ${ T} ) -> ( ${ T} , ( ${ T} ) -> ${ T} ) {
180
+ return ( acosh( x ) , { v in v / sqrt( x * x - 1 ) } )
181
+ }
182
+
183
+ @usableFromInline
184
+ func _vjpAtanh( _ x: ${ T} ) -> ( ${ T} , ( ${ T} ) -> ${ T} ) {
185
+ return ( atanh( x ) , { v in v / ( 1 - x * x) } )
186
+ }
187
+
188
+ @usableFromInline
189
+ func _vjpExpm1( _ x: ${ T} ) -> ( ${ T} , ( ${ T} ) -> ${ T} ) {
190
+ return ( expm1( x ) , { v in exp( x ) * v } )
191
+ }
192
+
193
+ @usableFromInline
194
+ func _vjpLog1p( _ x: ${ T} ) -> ( ${ T} , ( ${ T} ) -> ${ T} ) {
195
+ return ( log1p( x ) , { v in v / ( x + 1 ) } )
196
+ }
197
+
198
+ @usableFromInline
199
+ func _vjpErf( _ x: ${ T} ) -> ( ${ T} , ( ${ T} ) -> ${ T} ) {
200
+ return ( erf( x ) , { v in v * ${ T} ( M_2_SQRTPI ) * exp( - x * x) } )
201
+ }
202
+
203
+ @usableFromInline
204
+ func _vjpErfc( _ x: ${ T} ) -> ( ${ T} , ( ${ T} ) -> ${ T} ) {
205
+ return ( erfc( x ) , { v in v * - ${ T} ( M_2_SQRTPI ) * exp( - x * x) } )
206
+ }
125
207
% if T == 'Float80 ':
126
208
#endif
127
209
% end
@@ -201,7 +283,14 @@ UnaryIntrinsicFunctions = [
201
283
]
202
284
203
285
# SWIFT_ENABLE_TENSORFLOW
204
- HasVJP = [ " exp " , " log " , " tan " , " cos " , " sin " ]
286
+ HasVJP = [
287
+ 'acos', 'asin', 'atan', 'tan',
288
+ 'acosh', 'asinh', 'atanh', 'cosh', 'sinh', 'tanh',
289
+ 'expm1 ',
290
+ 'log1 p',
291
+ 'erf', 'erfc',
292
+ 'cos', 'sin', 'exp', 'exp2 ', 'log', 'log10 ', 'log2 '
293
+ ]
205
294
206
295
def AllFloatTypes( ) :
207
296
for bits in allFloatBits:
0 commit comments