Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit b882d0e

Browse files
rxweidan-zheng
authored andcommitted
[WIP] Try conforming to ElementaryFunctions.
1 parent 1e30f3a commit b882d0e

File tree

1 file changed

+144
-3
lines changed

1 file changed

+144
-3
lines changed

Sources/TensorFlow/Operators/Math.swift

Lines changed: 144 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,150 @@ func pow<T: BinaryFloatingPoint>(_ x: T, _ y: T) -> T {
2323
return T(pow(Double(x), Double(y)))
2424
}
2525

26-
// TODO:
27-
// - Consider explicit broadcasting for elementwise binary ops when
28-
// scalarization and rank getter are implemented.
26+
//===------------------------------------------------------------------------------------------===//
27+
// Generic elementary functions
28+
//===------------------------------------------------------------------------------------------===//
29+
30+
extension Tensor: ElementaryFunctions where Scalar: FloatingPoint {
31+
/// The square root of `x`.
32+
///
33+
/// For real types, if the argument is negative, either the result is NaN
34+
/// or a precondition failure occurs. For complex types, this function has
35+
/// a branch cut along the negative real axis.
36+
@differentiable
37+
public static func sqrt(_ x: Self) -> Self {
38+
return TensorFlow.sqrt(x)
39+
}
40+
41+
/// The cosine of `x`.
42+
///
43+
/// For real types, `x` is interpreted as an angle measured in radians.
44+
@differentiable
45+
public static func cos(_ x: Self) -> Self {
46+
return TensorFlow.cos(x)
47+
}
48+
49+
/// The sine of `x`.
50+
///
51+
/// For real types, `x` is interpreted as an angle measured in radians.
52+
@differentiable
53+
public static func sin(_ x: Self) -> Self {
54+
return TensorFlow.sin(x)
55+
}
56+
57+
/// The tangent of `x`.
58+
@differentiable
59+
public static func tan(_ x: Self) -> Self {
60+
return TensorFlow.tan(x)
61+
}
62+
63+
/// The acos function.
64+
public static func acos(_ x: Self) -> Self {
65+
fatalError("Unimplemented")
66+
}
67+
68+
/// The asin function.
69+
public static func asin(_ x: Self) -> Self {
70+
fatalError("Unimplemented")
71+
}
72+
73+
/// The atan function.
74+
public static func atan(_ x: Self) -> Self {
75+
fatalError("Unimplemented")
76+
}
77+
78+
/// The cosh function.
79+
public static func cosh(_ x: Self) -> Self {
80+
TensorFlow.cosh(x)
81+
}
82+
83+
/// The sinh function.
84+
public static func sinh(_ x: Self) -> Self {
85+
TensorFlow.sinh(x)
86+
}
87+
88+
/// The tanh function.
89+
public static func tanh(_ x: Self) -> Self {
90+
TensorFlow.tanh(x)
91+
}
92+
93+
/// The acosh function.
94+
public static func acosh(_ x: Self) -> Self {
95+
fatalError("Unimplemented")
96+
}
97+
98+
/// The asinh function.
99+
public static func asinh(_ x: Self) -> Self {
100+
fatalError("Unimplemented")
101+
}
102+
103+
/// The atanh function.
104+
public static func atanh(_ x: Self) -> Self {
105+
fatalError("Unimplemented")
106+
}
107+
108+
/// The exp function.
109+
public static func exp(_ x: Self) -> Self {
110+
TensorFlow.exp(x)
111+
}
112+
113+
/// The exp2 function.
114+
public static func exp2(_ x: Self) -> Self {
115+
fatalError("Unimplemented")
116+
}
117+
118+
/// The exp10 function.
119+
public static func exp10(_ x: Self) -> Self {
120+
fatalError("Unimplemented")
121+
}
122+
123+
/// The expm1 function.
124+
public static func expm1(_ x: Self) -> Self {
125+
fatalError("Unimplemented")
126+
}
127+
128+
/// The log function.
129+
public static func log(_ x: Self) -> Self {
130+
fatalError("Unimplemented")
131+
}
132+
133+
/// The log2 function.
134+
public static func log2(_ x: Self) -> Self {
135+
fatalError("Unimplemented")
136+
}
137+
138+
/// The log10 function.
139+
public static func log10(_ x: Self) -> Self {
140+
fatalError("Unimplemented")
141+
}
142+
143+
/// The log1p function.
144+
public static func log1p(_ x: Self) -> Self {
145+
fatalError("Unimplemented")
146+
}
147+
148+
/// `exp(y log(x))` computed without loss of intermediate precision.
149+
///
150+
/// For real types, if `x` is negative the result is NaN, even if `y` has
151+
/// an integral value. For complex types, there is a branch cut on the
152+
/// negative real axis.
153+
public static func pow(_ x: Self, _ y: Self) -> Self {
154+
fatalError("Unimplemented")
155+
}
156+
157+
/// `x` raised to the `n`th power.
158+
public static func pow(_ x: Self, _ n: Int) -> Self {
159+
fatalError("Unimplemented")
160+
}
161+
162+
/// The `n`th root of `x`.
163+
///
164+
/// For real types, if `x` is negative and `n` is even, the result is NaN.
165+
/// For complex types, there is a branch cut along the negative real axis.
166+
public static func root(_ x: Self, _ n: Int) -> Self {
167+
fatalError("Unimplemented")
168+
}
169+
}
29170

30171
//===------------------------------------------------------------------------------------------===//
31172
// Vector Space

0 commit comments

Comments
 (0)