Skip to content

Commit 79e0fea

Browse files
committed
Defines derivatives for remaning tgmath math functions.
* Inverse trigonometric functions. * Exponents and logarithms. * Hyperbolic functions. * Error functions ( erf, erfc ).
1 parent ec21e28 commit 79e0fea

File tree

2 files changed

+107
-2
lines changed

2 files changed

+107
-2
lines changed

stdlib/public/Platform/tgmath.swift.gyb

Lines changed: 90 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,11 +102,27 @@ func _vjpExp(_ x: ${T}) -> (${T}, (${T}) -> ${T}) {
102102
return (value, { v in value * v })
103103
}
104104

105+
@usableFromInline
106+
func _vjpExp2(_ x: ${T}) -> (${T}, (${T}) -> ${T}) {
107+
let value = exp2(x)
108+
return (value, { v in v * ${T}(M_LN2) * value })
109+
}
110+
105111
@usableFromInline
106112
func _vjpLog(_ x: ${T}) -> (${T}, (${T}) -> ${T}) {
107113
return (log(x), { v in v / x })
108114
}
109115

116+
@usableFromInline
117+
func _vjpLog10(_ x: ${T}) -> (${T}, (${T}) -> ${T}) {
118+
return (log10(x), { v in v * ${T}(M_LOG10E) / x})
119+
}
120+
121+
@usableFromInline
122+
func _vjpLog2(_ x: ${T}) -> (${T}, (${T}) -> ${T}) {
123+
return (log2(x), { v in v / (${T}(M_LN2) * x)})
124+
}
125+
110126
@usableFromInline
111127
func _vjpSin(_ x: ${T}) -> (${T}, (${T}) -> ${T}) {
112128
return (sin(x), { v in v * cos(x) })
@@ -122,6 +138,72 @@ func _vjpTan(_ x: ${T}) -> (${T}, (${T}) -> ${T}) {
122138
let value = tan(x)
123139
return (value, { v in v * (1 + value * value) })
124140
}
141+
142+
@usableFromInline
143+
func _vjpAsin(_ x: ${T}) -> (${T}, (${T}) -> ${T}) {
144+
return (asin(x), { v in v / sqrt(1 - x * x) })
145+
}
146+
147+
@usableFromInline
148+
func _vjpAcos(_ x: ${T}) -> (${T}, (${T}) -> ${T}) {
149+
return (acos(x), { v in -v / sqrt(1 - x * x) })
150+
}
151+
152+
@usableFromInline
153+
func _vjpAtan(_ x: ${T}) -> (${T}, (${T}) -> ${T}) {
154+
return (atan(x), { v in v / (1 + x * x) })
155+
}
156+
157+
@usableFromInline
158+
func _vjpSinh(_ x: ${T}) -> (${T}, (${T}) -> ${T}) {
159+
return (sinh(x), { v in v * cosh(x) })
160+
}
161+
162+
@usableFromInline
163+
func _vjpCosh(_ x: ${T}) -> (${T}, (${T}) -> ${T}) {
164+
return (cosh(x), { v in v * sinh(x) })
165+
}
166+
167+
@usableFromInline
168+
func _vjpTanh(_ x: ${T}) -> (${T}, (${T}) -> ${T}) {
169+
let value = tanh(x)
170+
return (value, { v in v * (1 - value * value) })
171+
}
172+
173+
@usableFromInline
174+
func _vjpAsinh(_ x: ${T}) -> (${T}, (${T}) -> ${T}) {
175+
return (asinh(x), { v in v / sqrt(1 + x * x) })
176+
}
177+
178+
@usableFromInline
179+
func _vjpAcosh(_ x: ${T}) -> (${T}, (${T}) -> ${T}) {
180+
return (acosh(x), { v in v / sqrt(x * x - 1) })
181+
}
182+
183+
@usableFromInline
184+
func _vjpAtanh(_ x: ${T}) -> (${T}, (${T}) -> ${T}) {
185+
return (atanh(x), { v in v / (1 - x * x) })
186+
}
187+
188+
@usableFromInline
189+
func _vjpExpm1(_ x: ${T}) -> (${T}, (${T}) -> ${T}) {
190+
return (expm1(x), { v in exp(x) * v })
191+
}
192+
193+
@usableFromInline
194+
func _vjpLog1p(_ x: ${T}) -> (${T}, (${T}) -> ${T}) {
195+
return (log1p(x), { v in v / ( x + 1) })
196+
}
197+
198+
@usableFromInline
199+
func _vjpErf(_ x: ${T}) -> (${T}, (${T}) -> ${T}) {
200+
return (erf(x), { v in v * ${T}(M_2_SQRTPI) * exp(-x * x) })
201+
}
202+
203+
@usableFromInline
204+
func _vjpErfc(_ x: ${T}) -> (${T}, (${T}) -> ${T}) {
205+
return (erfc(x), { v in v * -${T}(M_2_SQRTPI) * exp(-x * x) })
206+
}
125207
% if T == 'Float80':
126208
#endif
127209
% end
@@ -201,7 +283,14 @@ UnaryIntrinsicFunctions = [
201283
]
202284

203285
# SWIFT_ENABLE_TENSORFLOW
204-
HasVJP = ["exp", "log", "tan", "cos", "sin"]
286+
HasVJP = [
287+
'acos', 'asin', 'atan', 'tan',
288+
'acosh', 'asinh', 'atanh', 'cosh', 'sinh', 'tanh',
289+
'expm1',
290+
'log1p',
291+
'erf', 'erfc',
292+
'cos', 'sin', 'exp', 'exp2', 'log', 'log10', 'log2'
293+
]
205294

206295
def AllFloatTypes():
207296
for bits in allFloatBits:

test/stdlib/tgmath.swift.gyb

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,10 +248,26 @@ MathTests.test("${T}") {
248248
% for T in ['Float', 'Float80']:
249249
MathTests.test("gradient_${T}") {
250250
expectEqualWithTolerance(7.3890560989306502274, gradient(at: 2.0 as ${T}, in: exp), ulps:16)
251+
expectEqualWithTolerance(2.772588722239781145, gradient(at: 2.0 as ${T}, in: exp2), ulps:16)
252+
expectEqualWithTolerance(7.3890560989306502274, gradient(at: 2.0 as ${T}, in: expm1), ulps:16)
251253
expectEqualWithTolerance(0.5, gradient(at: 2.0 as ${T}, in: log), ulps:16)
254+
expectEqualWithTolerance(0.21714724095162590833, gradient(at: 2.0 as ${T}, in: log10), ulps:16)
255+
expectEqualWithTolerance(0.7213475204444817278, gradient(at: 2.0 as ${T}, in: log2), ulps:16)
256+
expectEqualWithTolerance(0.33333333333333333334, gradient(at: 2.0 as ${T}, in: log1p), ulps:16)
252257
expectEqualWithTolerance(5.774399204041917612, gradient(at: 2.0 as ${T}, in: tan), ulps:16)
253-
expectEqualWithTolerance(-0.416146836547142387, gradient(at: 2.0 as ${T}, in: sin), ulps:16)
254258
expectEqualWithTolerance(-0.9092974268256816954, gradient(at: 2.0 as ${T}, in: cos), ulps:16)
259+
expectEqualWithTolerance(-0.416146836547142387, gradient(at: 2.0 as ${T}, in: sin), ulps:16)
260+
expectEqualWithTolerance(1.154700538379251529, gradient(at: 0.5 as ${T}, in: asin), ulps:16)
261+
expectEqualWithTolerance(-1.154700538379251529, gradient(at: 0.5 as ${T}, in: acos), ulps:16)
262+
expectEqualWithTolerance(0.8, gradient(at: 0.5 as ${T}, in: atan), ulps:16)
263+
expectEqualWithTolerance(3.7621956910836314597, gradient(at: 2.0 as ${T}, in: sinh), ulps:16)
264+
expectEqualWithTolerance(3.6268604078470187677, gradient(at: 2.0 as ${T}, in: cosh), ulps:16)
265+
expectEqualWithTolerance(0.07065082485316446565, gradient(at: 2.0 as ${T}, in: tanh), ulps:16)
266+
expectEqualWithTolerance(0.44721359549995793928, gradient(at: 2.0 as ${T}, in: asinh), ulps:16)
267+
expectEqualWithTolerance(0.5773502691896257645, gradient(at: 2.0 as ${T}, in: acosh), ulps:16)
268+
expectEqualWithTolerance(1.3333333333333333334, gradient(at: 0.5 as ${T}, in: atanh), ulps:16)
269+
expectEqualWithTolerance(0.020666985354092053575, gradient(at: 2.0 as ${T}, in: erf), ulps:16)
270+
expectEqualWithTolerance(-0.020666985354092053575, gradient(at: 2.0 as ${T}, in: erfc), ulps:16)
255271
}
256272
%end
257273

0 commit comments

Comments
 (0)