Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 4612b60

Browse files
committed
Conditionally conform Tensor to ElementaryFunctions.
Conditionally conform `Tensor` to `ElementaryFunctions` where `Scalar: TensorFlowFloatingPoint`.
1 parent c4a7194 commit 4612b60

File tree

3 files changed

+149
-14
lines changed

3 files changed

+149
-14
lines changed

Sources/TensorFlow/Initializers.swift

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -459,7 +459,12 @@ fileprivate extension Tensor where Scalar: BinaryFloatingPoint {
459459
let fanIn = shape[shape.count - 2] * receptiveField
460460
let fanOut = shape[shape.count - 1] * receptiveField
461461
let minusOneToOne = 2 * randomUniform - 1
462-
return sqrt(Scalar(6) / Scalar(fanIn + fanOut)) * minusOneToOne
462+
#if os(macOS) || os(iOS) || os(watchOS) || os(tvOS)
463+
let _sqrt = Darwin.sqrt as (Scalar) -> Scalar
464+
#else
465+
let _sqrt = Glibc.sqrt as (Scalar) -> Scalar
466+
#endif
467+
return _sqrt(Scalar(6) / Scalar(fanIn + fanOut)) * minusOneToOne
463468
}
464469
}
465470

Sources/TensorFlow/Operators/Math.swift

Lines changed: 142 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,148 @@
1515
infix operator .>: ComparisonPrecedence
1616
infix operator .==: ComparisonPrecedence
1717

18-
// `pow` is defined in Darwin/Glibc on `Float` and `Double`, but there doesn't exist a generic
19-
// version for `FloatingPoint`.
20-
// This is a manual definition.
21-
@inlinable
22-
func pow<T: BinaryFloatingPoint>(_ x: T, _ y: T) -> T {
23-
T(pow(Double(x), Double(y)))
24-
}
25-
2618
// TODO:
2719
// - Consider explicit broadcasting for elementwise binary ops when
2820
// scalarization and rank getter are implemented.
2921

22+
//===------------------------------------------------------------------------------------------===//
23+
// Generic elementary functions
24+
//===------------------------------------------------------------------------------------------===//
25+
26+
/// Conformance of `Tensor` to the standard library's `ElementaryFunctions` protocol.
///
/// Each protocol requirement simply forwards to the corresponding free function in the
/// `TensorFlow` module, which applies the operation elementwise. Calls are qualified with
/// `TensorFlow.` to select the free function rather than recursing into these static methods.
extension Tensor: ElementaryFunctions where Scalar: TensorFlowFloatingPoint {
    /// The square root of `x`.
    ///
    /// For real types, if `x` is negative the result is `.nan`. For complex
    /// types there is a branch cut on the negative real axis.
    public static func sqrt(_ x: Self) -> Self {
        TensorFlow.sqrt(x)
    }

    /// The cosine of `x`, interpreted as an angle in radians.
    public static func cos(_ x: Self) -> Self {
        TensorFlow.cos(x)
    }

    /// The sine of `x`, interpreted as an angle in radians.
    public static func sin(_ x: Self) -> Self {
        TensorFlow.sin(x)
    }

    /// The tangent of `x`, interpreted as an angle in radians.
    public static func tan(_ x: Self) -> Self {
        TensorFlow.tan(x)
    }

    /// The inverse cosine of `x` in radians.
    public static func acos(_ x: Self) -> Self {
        TensorFlow.acos(x)
    }

    /// The inverse sine of `x` in radians.
    public static func asin(_ x: Self) -> Self {
        TensorFlow.asin(x)
    }

    /// The inverse tangent of `x` in radians.
    public static func atan(_ x: Self) -> Self {
        TensorFlow.atan(x)
    }

    /// The hyperbolic cosine of `x`.
    public static func cosh(_ x: Self) -> Self {
        TensorFlow.cosh(x)
    }

    /// The hyperbolic sine of `x`.
    public static func sinh(_ x: Self) -> Self {
        TensorFlow.sinh(x)
    }

    /// The hyperbolic tangent of `x`.
    public static func tanh(_ x: Self) -> Self {
        TensorFlow.tanh(x)
    }

    /// The inverse hyperbolic cosine of `x`.
    public static func acosh(_ x: Self) -> Self {
        TensorFlow.acosh(x)
    }

    /// The inverse hyperbolic sine of `x`.
    public static func asinh(_ x: Self) -> Self {
        TensorFlow.asinh(x)
    }

    /// The inverse hyperbolic tangent of `x`.
    public static func atanh(_ x: Self) -> Self {
        TensorFlow.atanh(x)
    }

    /// The exponential function applied to `x`, or `e**x`.
    public static func exp(_ x: Self) -> Self {
        TensorFlow.exp(x)
    }

    /// Two raised to the power `x`.
    public static func exp2(_ x: Self) -> Self {
        TensorFlow.exp2(x)
    }

    /// Ten raised to the power `x`.
    public static func exp10(_ x: Self) -> Self {
        TensorFlow.exp10(x)
    }

    /// `exp(x) - 1` evaluated so as to preserve accuracy close to zero.
    public static func expm1(_ x: Self) -> Self {
        TensorFlow.expm1(x)
    }

    /// The natural logarithm of `x`.
    public static func log(_ x: Self) -> Self {
        TensorFlow.log(x)
    }

    /// The base-two logarithm of `x`.
    public static func log2(_ x: Self) -> Self {
        TensorFlow.log2(x)
    }

    /// The base-ten logarithm of `x`.
    public static func log10(_ x: Self) -> Self {
        TensorFlow.log10(x)
    }

    /// `log(1 + x)` evaluated so as to preserve accuracy close to zero.
    public static func log1p(_ x: Self) -> Self {
        TensorFlow.log1p(x)
    }

    /// `x**y` interpreted as `exp(y * log(x))`.
    ///
    /// For real types, if `x` is negative the result is NaN, even if `y` has
    /// an integral value. For complex types, there is a branch cut on the
    /// negative real axis.
    public static func pow(_ x: Self, _ y: Self) -> Self {
        TensorFlow.pow(x, y)
    }

    /// `x` raised to the `n`th power.
    ///
    /// The product of `n` copies of `x`.
    public static func pow(_ x: Self, _ n: Int) -> Self {
        TensorFlow.pow(x, n)
    }

    /// The `n`th root of `x`.
    ///
    /// For real types, if `x` is negative and `n` is even, the result is NaN.
    /// For complex types, there is a branch cut along the negative real axis.
    public static func root(_ x: Self, _ n: Int) -> Self {
        TensorFlow.root(x, n)
    }
}
159+
30160
//===------------------------------------------------------------------------------------------===//
31161
// Vector Space
32162
//===------------------------------------------------------------------------------------------===//
@@ -1580,7 +1710,7 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint {
15801710
/// Returns the standard deviation of the elements along the specified axes. The reduced
/// dimensions are removed.
///
/// - Parameter axes: The dimensions to reduce.
@inlinable
@differentiable(wrt: self)
func standardDeviation(squeezingAxes axes: Tensor<Int32>) -> Tensor {
    // Qualified with `TensorFlow.` so the free `sqrt` is chosen over the
    // `ElementaryFunctions.sqrt` static requirement that `Tensor` now satisfies.
    TensorFlow.sqrt(variance(squeezingAxes: axes))
}
15851715

15861716
/// Returns the standard deviation of the elements along the specified axes. The reduced
@@ -1591,7 +1721,7 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint {
15911721
/// Returns the standard deviation of the elements along the specified axes. The reduced
/// dimensions are removed.
///
/// - Parameter axes: The dimensions to reduce.
@inlinable
@differentiable(wrt: self)
func standardDeviation(squeezingAxes axes: [Int]) -> Tensor {
    // Qualified with `TensorFlow.` so the free `sqrt` is chosen over the
    // `ElementaryFunctions.sqrt` static requirement that `Tensor` now satisfies.
    TensorFlow.sqrt(variance(squeezingAxes: axes))
}
15961726

15971727
/// Returns the standard deviation of the elements along the specified axes. The reduced
@@ -1625,7 +1755,7 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint {
16251755
/// Returns the standard deviation of the elements along the specified axes. The reduced
/// dimensions are retained.
///
/// - Parameter axes: The dimensions to reduce.
@inlinable
@differentiable(wrt: self)
func standardDeviation(alongAxes axes: Tensor<Int32>) -> Tensor {
    // Qualified with `TensorFlow.` so the free `sqrt` is chosen over the
    // `ElementaryFunctions.sqrt` static requirement that `Tensor` now satisfies.
    TensorFlow.sqrt(variance(alongAxes: axes))
}
16301760

16311761
/// Returns the standard deviation of the elements along the specified axes. The reduced
@@ -1649,7 +1779,7 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint {
16491779
/// Returns the standard deviation of the elements along the specified axes. The reduced
/// dimensions are retained.
///
/// - Parameter axes: The dimensions to reduce.
@inlinable
@differentiable(wrt: self)
func standardDeviation(alongAxes axes: Int...) -> Tensor {
    // The variadic arguments become an `[Int]`, dispatching to the array overload.
    // Qualified with `TensorFlow.` so the free `sqrt` is chosen over the
    // `ElementaryFunctions.sqrt` static requirement that `Tensor` now satisfies.
    TensorFlow.sqrt(variance(alongAxes: axes))
}
16541784
}
16551785

Sources/TensorFlow/Operators/NN.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint {
6161
let norm = diff * inv
6262

6363
let dNorm = v * scale
64-
let dVariance = -(dNorm * diff).sum(alongAxes: axis) / 2 * pow(inv, -3)
64+
let dVariance = -(dNorm * diff).sum(alongAxes: axis) / 2 * TensorFlow.pow(inv, -3)
6565
// Note: `dMean` is split into two lines to avoid the "compiler is unable to type-check
6666
// this expression in reasonable time" error.
6767
var dMean = (-dNorm * inv).sum(alongAxes: axis)

0 commit comments

Comments
 (0)