Skip to content

Commit 62c019b

Browse files
committed
Adding derivatives for generic functions: sqrt and fma.
with respective tests.
1 parent c92e34b commit 62c019b

File tree

2 files changed

+33
-0
lines changed

2 files changed

+33
-0
lines changed

stdlib/public/Platform/tgmath.swift.gyb

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,20 @@ public func fabs<T: FloatingPoint>(_ x: T) -> T {
2020
}
2121

2222
@_transparent
23+
@differentiable(
24+
vjp: _vjpSqrt
25+
where T : Differentiable & FloatingPoint, T == T.TangentVector
26+
)
2327
public func sqrt<T: FloatingPoint>(_ x: T) -> T {
2428
return x.squareRoot()
2529
}
2630

2731
@_transparent
32+
@differentiable(
33+
wrt: (x, y, z),
34+
vjp: _vjpFma
35+
where T : Differentiable & FloatingPoint, T == T.TangentVector
36+
)
2837
public func fma<T: FloatingPoint>(_ x: T, _ y: T, _ z: T) -> T {
2938
return z.addingProduct(x, y)
3039
}
@@ -82,6 +91,25 @@ public func frexp<T: BinaryFloatingPoint>(_ x: T) -> (T, Int) {
8291
return (x.significand / 2, Int(x.exponent + 1))
8392
}
8493

94+
// SWIFT_ENABLE_TENSORFLOW
95+
@usableFromInline
96+
func _vjpSqrt<T: FloatingPoint & Differentiable> (
97+
_ x: T
98+
) -> (T, (T) -> T) where T == T.TangentVector {
99+
let value = x.squareRoot()
100+
return (value, {v in (1 / 2) * ( 1 / value) * v})
101+
}
102+
103+
@usableFromInline
104+
func _vjpFma<T: FloatingPoint & Differentiable> (
105+
_ x: T,
106+
_ y: T,
107+
_ z: T
108+
) -> (T, (T) -> (T, T, T)) where T == T.TangentVector {
109+
return (fma(x, y, z),
110+
{v in return (v * y, v * x, v)})
111+
}
112+
85113
%for T in ['Float','Double']:
86114
@available(swift, deprecated: 4.2, renamed: "scalbn")
87115
@_transparent

test/stdlib/tgmath.swift.gyb

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,11 @@ MathTests.test("gradient_${T}") {
268268
expectEqualWithTolerance(1.3333333333333333334, gradient(at: 0.5 as ${T}, in: atanh), ulps: 16)
269269
expectEqualWithTolerance(0.020666985354092053575, gradient(at: 2.0 as ${T}, in: erf), ulps: 16)
270270
expectEqualWithTolerance(-0.020666985354092053575, gradient(at: 2.0 as ${T}, in: erfc), ulps: 16)
271+
expectEqualWithTolerance(0.35355339059327376222, gradient(at: 2.0 as ${T}, in: {x in sqrt(x)}), ulps: 16)
272+
let fma_grad = gradient(at: 4.0 as ${T}, 5.0 as ${T}, 6.0 as ${T}, in: {x, y, z in fma(x, y, z)})
273+
expectEqualWithTolerance(5.0, fma_grad.0, ulps: 16)
274+
expectEqualWithTolerance(4.0, fma_grad.1, ulps: 16)
275+
expectEqualWithTolerance(1.0, fma_grad.2, ulps: 16)
271276
}
272277
%end
273278

0 commit comments

Comments
 (0)