Skip to content

Commit c76bde9

Browse files
vguerrarxwei
authored andcommitted
Defines derivatives for remaining tgmath math functions. (#27559)
* Inverse trigonometric functions. * Exponents and logarithms. * Hyperbolic functions. * Error functions ( erf, erfc ). * Free generic functions: sqrt, fma. Partially resolves [TF-812](https://bugs.swift.org/browse/TF-812).
1 parent 4424a7c commit c76bde9

File tree

2 files changed

+145
-6
lines changed

2 files changed

+145
-6
lines changed

stdlib/public/Platform/tgmath.swift.gyb

Lines changed: 119 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,22 @@ public func fabs<T: FloatingPoint>(_ x: T) -> T {
2020
}
2121

2222
@_transparent
23+
// SWIFT_ENABLE_TENSORFLOW
24+
@differentiable(
25+
vjp: _vjpSqrt
26+
where T : Differentiable & FloatingPoint, T == T.TangentVector
27+
)
2328
public func sqrt<T: FloatingPoint>(_ x: T) -> T {
2429
return x.squareRoot()
2530
}
2631

2732
@_transparent
33+
// SWIFT_ENABLE_TENSORFLOW
34+
@differentiable(
35+
wrt: (x, y, z),
36+
vjp: _vjpFma
37+
where T : Differentiable & FloatingPoint, T == T.TangentVector
38+
)
2839
public func fma<T: FloatingPoint>(_ x: T, _ y: T, _ z: T) -> T {
2940
return z.addingProduct(x, y)
3041
}
@@ -82,6 +93,24 @@ public func frexp<T: BinaryFloatingPoint>(_ x: T) -> (T, Int) {
8293
return (x.significand / 2, Int(x.exponent + 1))
8394
}
8495

96+
// SWIFT_ENABLE_TENSORFLOW
97+
@usableFromInline
98+
func _vjpSqrt<T: FloatingPoint & Differentiable> (
99+
_ x: T
100+
) -> (T, (T) -> T) where T == T.TangentVector {
101+
let value = x.squareRoot()
102+
return (value, { v in v / (2 * value) })
103+
}
104+
105+
@usableFromInline
106+
func _vjpFma<T: FloatingPoint & Differentiable> (
107+
_ x: T,
108+
_ y: T,
109+
_ z: T
110+
) -> (T, (T) -> (T, T, T)) where T == T.TangentVector {
111+
return (fma(x, y, z), { v in (v * y, v * x, v) })
112+
}
113+
85114
%for T in ['Float','Double']:
86115
@available(swift, deprecated: 4.2, renamed: "scalbn")
87116
@_transparent
@@ -102,11 +131,27 @@ func _vjpExp(_ x: ${T}) -> (${T}, (${T}) -> ${T}) {
102131
return (value, { v in value * v })
103132
}
104133

134+
@usableFromInline
135+
func _vjpExp2(_ x: ${T}) -> (${T}, (${T}) -> ${T}) {
136+
let value = exp2(x)
137+
return (value, { v in v * ${T}(M_LN2) * value })
138+
}
139+
105140
@usableFromInline
106141
func _vjpLog(_ x: ${T}) -> (${T}, (${T}) -> ${T}) {
107142
return (log(x), { v in v / x })
108143
}
109144

145+
@usableFromInline
146+
func _vjpLog10(_ x: ${T}) -> (${T}, (${T}) -> ${T}) {
147+
return (log10(x), { v in v * ${T}(M_LOG10E) / x })
148+
}
149+
150+
@usableFromInline
151+
func _vjpLog2(_ x: ${T}) -> (${T}, (${T}) -> ${T}) {
152+
return (log2(x), { v in v / (${T}(M_LN2) * x) })
153+
}
154+
110155
@usableFromInline
111156
func _vjpSin(_ x: ${T}) -> (${T}, (${T}) -> ${T}) {
112157
return (sin(x), { v in v * cos(x) })
@@ -122,6 +167,72 @@ func _vjpTan(_ x: ${T}) -> (${T}, (${T}) -> ${T}) {
122167
let value = tan(x)
123168
return (value, { v in v * (1 + value * value) })
124169
}
170+
171+
@usableFromInline
172+
func _vjpAsin(_ x: ${T}) -> (${T}, (${T}) -> ${T}) {
173+
return (asin(x), { v in v / sqrt(1 - x * x) })
174+
}
175+
176+
@usableFromInline
177+
func _vjpAcos(_ x: ${T}) -> (${T}, (${T}) -> ${T}) {
178+
return (acos(x), { v in -v / sqrt(1 - x * x) })
179+
}
180+
181+
@usableFromInline
182+
func _vjpAtan(_ x: ${T}) -> (${T}, (${T}) -> ${T}) {
183+
return (atan(x), { v in v / (1 + x * x) })
184+
}
185+
186+
@usableFromInline
187+
func _vjpSinh(_ x: ${T}) -> (${T}, (${T}) -> ${T}) {
188+
return (sinh(x), { v in v * cosh(x) })
189+
}
190+
191+
@usableFromInline
192+
func _vjpCosh(_ x: ${T}) -> (${T}, (${T}) -> ${T}) {
193+
return (cosh(x), { v in v * sinh(x) })
194+
}
195+
196+
@usableFromInline
197+
func _vjpTanh(_ x: ${T}) -> (${T}, (${T}) -> ${T}) {
198+
let value = tanh(x)
199+
return (value, { v in v * (1 - value * value) })
200+
}
201+
202+
@usableFromInline
203+
func _vjpAsinh(_ x: ${T}) -> (${T}, (${T}) -> ${T}) {
204+
return (asinh(x), { v in v / sqrt(1 + x * x) })
205+
}
206+
207+
@usableFromInline
208+
func _vjpAcosh(_ x: ${T}) -> (${T}, (${T}) -> ${T}) {
209+
return (acosh(x), { v in v / sqrt(x * x - 1) })
210+
}
211+
212+
@usableFromInline
213+
func _vjpAtanh(_ x: ${T}) -> (${T}, (${T}) -> ${T}) {
214+
return (atanh(x), { v in v / (1 - x * x) })
215+
}
216+
217+
@usableFromInline
218+
func _vjpExpm1(_ x: ${T}) -> (${T}, (${T}) -> ${T}) {
219+
return (expm1(x), { v in exp(x) * v })
220+
}
221+
222+
@usableFromInline
223+
func _vjpLog1p(_ x: ${T}) -> (${T}, (${T}) -> ${T}) {
224+
return (log1p(x), { v in v / (x + 1) })
225+
}
226+
227+
@usableFromInline
228+
func _vjpErf(_ x: ${T}) -> (${T}, (${T}) -> ${T}) {
229+
return (erf(x), { v in v * ${T}(M_2_SQRTPI) * exp(-x * x) })
230+
}
231+
232+
@usableFromInline
233+
func _vjpErfc(_ x: ${T}) -> (${T}, (${T}) -> ${T}) {
234+
return (erfc(x), { v in v * -${T}(M_2_SQRTPI) * exp(-x * x) })
235+
}
125236
% if T == 'Float80':
126237
#endif
127238
% end
@@ -201,7 +312,14 @@ UnaryIntrinsicFunctions = [
201312
]
202313

203314
# SWIFT_ENABLE_TENSORFLOW
204-
HasVJP = ["exp", "log", "tan", "cos", "sin"]
315+
HasVJP = [
316+
'acos', 'asin', 'atan', 'tan',
317+
'acosh', 'asinh', 'atanh', 'cosh', 'sinh', 'tanh',
318+
'expm1',
319+
'log1p',
320+
'erf', 'erfc',
321+
'cos', 'sin', 'exp', 'exp2', 'log', 'log10', 'log2'
322+
]
205323

206324
def AllFloatTypes():
207325
for bits in allFloatBits:

test/stdlib/tgmath.swift.gyb

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -247,11 +247,32 @@ MathTests.test("${T}") {
247247
// SWIFT_ENABLE_TENSORFLOW
248248
% for T in ['Float', 'Float80']:
249249
MathTests.test("gradient_${T}") {
250-
expectEqualWithTolerance(7.3890560989306502274, gradient(at: 2.0 as ${T}, in: exp), ulps:16)
251-
expectEqualWithTolerance(0.5, gradient(at: 2.0 as ${T}, in: log), ulps:16)
252-
expectEqualWithTolerance(5.774399204041917612, gradient(at: 2.0 as ${T}, in: tan), ulps:16)
253-
expectEqualWithTolerance(-0.416146836547142387, gradient(at: 2.0 as ${T}, in: sin), ulps:16)
254-
expectEqualWithTolerance(-0.9092974268256816954, gradient(at: 2.0 as ${T}, in: cos), ulps:16)
250+
expectEqualWithTolerance(7.3890560989306502274, gradient(at: 2.0 as ${T}, in: exp), ulps: 16)
251+
expectEqualWithTolerance(2.772588722239781145, gradient(at: 2.0 as ${T}, in: exp2), ulps: 16)
252+
expectEqualWithTolerance(7.3890560989306502274, gradient(at: 2.0 as ${T}, in: expm1), ulps: 16)
253+
expectEqualWithTolerance(0.5, gradient(at: 2.0 as ${T}, in: log), ulps: 16)
254+
expectEqualWithTolerance(0.21714724095162590833, gradient(at: 2.0 as ${T}, in: log10), ulps: 16)
255+
expectEqualWithTolerance(0.7213475204444817278, gradient(at: 2.0 as ${T}, in: log2), ulps: 16)
256+
expectEqualWithTolerance(0.33333333333333333334, gradient(at: 2.0 as ${T}, in: log1p), ulps: 16)
257+
expectEqualWithTolerance(5.774399204041917612, gradient(at: 2.0 as ${T}, in: tan), ulps: 16)
258+
expectEqualWithTolerance(-0.9092974268256816954, gradient(at: 2.0 as ${T}, in: cos), ulps: 16)
259+
expectEqualWithTolerance(-0.416146836547142387, gradient(at: 2.0 as ${T}, in: sin), ulps: 16)
260+
expectEqualWithTolerance(1.154700538379251529, gradient(at: 0.5 as ${T}, in: asin), ulps: 16)
261+
expectEqualWithTolerance(-1.154700538379251529, gradient(at: 0.5 as ${T}, in: acos), ulps: 16)
262+
expectEqualWithTolerance(0.8, gradient(at: 0.5 as ${T}, in: atan), ulps: 16)
263+
expectEqualWithTolerance(3.7621956910836314597, gradient(at: 2.0 as ${T}, in: sinh), ulps: 16)
264+
expectEqualWithTolerance(3.6268604078470187677, gradient(at: 2.0 as ${T}, in: cosh), ulps: 16)
265+
expectEqualWithTolerance(0.07065082485316446565, gradient(at: 2.0 as ${T}, in: tanh), ulps: 16)
266+
expectEqualWithTolerance(0.44721359549995793928, gradient(at: 2.0 as ${T}, in: asinh), ulps: 16)
267+
expectEqualWithTolerance(0.5773502691896257645, gradient(at: 2.0 as ${T}, in: acosh), ulps: 16)
268+
expectEqualWithTolerance(1.3333333333333333334, gradient(at: 0.5 as ${T}, in: atanh), ulps: 16)
269+
expectEqualWithTolerance(0.020666985354092053575, gradient(at: 2.0 as ${T}, in: erf), ulps: 16)
270+
expectEqualWithTolerance(-0.020666985354092053575, gradient(at: 2.0 as ${T}, in: erfc), ulps: 16)
271+
expectEqualWithTolerance(0.35355339059327376222, gradient(at: 2.0 as ${T}, in: { sqrt($0) }), ulps: 16)
272+
let fmaGrad = gradient(at: 4.0 as ${T}, 5.0 as ${T}, 6.0 as ${T}, in: { x, y, z in fma(x, y, z) })
273+
expectEqualWithTolerance(5.0, fmaGrad.0, ulps: 16)
274+
expectEqualWithTolerance(4.0, fmaGrad.1, ulps: 16)
275+
expectEqualWithTolerance(1.0, fmaGrad.2, ulps: 16)
255276
}
256277
%end
257278

0 commit comments

Comments
 (0)