Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 5e24305

Browse files
jon-towrxwei
authored andcommitted
Add @differentiable to all Tensor functions conforming to ElemenartyFunctions (#369)
1 parent 4539792 commit 5e24305

File tree

1 file changed

+23
-2
lines changed

1 file changed

+23
-2
lines changed

Sources/TensorFlow/Operators/Math.swift

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,106 +44,127 @@ extension Tensor: ElementaryFunctions where Scalar: TensorFlowFloatingPoint {
4444
///
4545
/// For real types, if `x` is negative the result is `.nan`. For complex
4646
/// types there is a branch cut on the negative real axis.
47+
@differentiable
4748
public static func sqrt(_ x: Self) -> Self {
4849
TensorFlow.sqrt(x)
4950
}
5051

5152
/// The cosine of `x`, interpreted as an angle in radians.
53+
@differentiable
5254
public static func cos(_ x: Self) -> Self {
5355
TensorFlow.cos(x)
5456
}
5557

5658
/// The sine of `x`, interpreted as an angle in radians.
59+
@differentiable
5760
public static func sin(_ x: Self) -> Self {
5861
TensorFlow.sin(x)
5962
}
6063

6164
/// The tangent of `x`, interpreted as an angle in radians.
65+
@differentiable
6266
public static func tan(_ x: Self) -> Self {
6367
TensorFlow.tan(x)
6468
}
6569

6670
/// The inverse cosine of `x` in radians.
71+
@differentiable
6772
public static func acos(_ x: Self) -> Self {
6873
TensorFlow.acos(x)
6974
}
7075

7176
/// The inverse sine of `x` in radians.
77+
@differentiable
7278
public static func asin(_ x: Self) -> Self {
7379
TensorFlow.asin(x)
7480
}
7581

7682
/// The inverse tangent of `x` in radians.
83+
@differentiable
7784
public static func atan(_ x: Self) -> Self {
7885
TensorFlow.atan(x)
7986
}
8087

8188
/// The hyperbolic cosine of `x`.
89+
@differentiable
8290
public static func cosh(_ x: Self) -> Self {
8391
TensorFlow.cosh(x)
8492
}
8593

8694
/// The hyperbolic sine of `x`.
95+
@differentiable
8796
public static func sinh(_ x: Self) -> Self {
8897
TensorFlow.sinh(x)
8998
}
9099

91100
/// The hyperbolic tangent of `x`.
101+
@differentiable
92102
public static func tanh(_ x: Self) -> Self {
93103
TensorFlow.tanh(x)
94104
}
95105

96106
/// The inverse hyperbolic cosine of `x`.
107+
@differentiable
97108
public static func acosh(_ x: Self) -> Self {
98109
TensorFlow.acosh(x)
99110
}
100111

101112
/// The inverse hyperbolic sine of `x`.
113+
@differentiable
102114
public static func asinh(_ x: Self) -> Self {
103115
TensorFlow.asinh(x)
104116
}
105117

106118
/// The inverse hyperbolic tangent of `x`.
119+
@differentiable
107120
public static func atanh(_ x: Self) -> Self {
108121
TensorFlow.atanh(x)
109122
}
110123

111124
/// The exponential function applied to `x`, or `e**x`.
125+
@differentiable
112126
public static func exp(_ x: Self) -> Self {
113127
TensorFlow.exp(x)
114128
}
115129

116130
/// Two raised to to power `x`.
131+
@differentiable
117132
public static func exp2(_ x: Self) -> Self {
118133
TensorFlow.exp2(x)
119134
}
120135

121136
/// Ten raised to to power `x`.
137+
@differentiable
122138
public static func exp10(_ x: Self) -> Self {
123139
TensorFlow.exp10(x)
124140
}
125141

126142
/// `exp(x) - 1` evaluated so as to preserve accuracy close to zero.
143+
@differentiable
127144
public static func expm1(_ x: Self) -> Self {
128145
TensorFlow.expm1(x)
129146
}
130147

131148
/// The natural logarithm of `x`.
149+
@differentiable
132150
public static func log(_ x: Self) -> Self {
133151
TensorFlow.log(x)
134152
}
135153

136154
/// The base-two logarithm of `x`.
155+
@differentiable
137156
public static func log2(_ x: Self) -> Self {
138157
TensorFlow.log2(x)
139158
}
140159

141160
/// The base-ten logarithm of `x`.
161+
@differentiable
142162
public static func log10(_ x: Self) -> Self {
143163
TensorFlow.log10(x)
144164
}
145165

146166
/// `log(1 + x)` evaluated so as to preserve accuracy close to zero.
167+
@differentiable
147168
public static func log1p(_ x: Self) -> Self {
148169
TensorFlow.log1p(x)
149170
}
@@ -930,14 +951,14 @@ internal func _vjpExp<T: TensorFlowFloatingPoint>(
930951

931952
/// Returns two raised to the power of the specified tensor element-wise.
932953
@inlinable
933-
// @differentiable
954+
@differentiable
934955
public func exp2<T: TensorFlowFloatingPoint>(_ x: Tensor<T>) -> Tensor<T> {
935956
pow(2, x)
936957
}
937958

938959
/// Returns ten raised to the power of the specified tensor element-wise.
939960
@inlinable
940-
// @differentiable
961+
@differentiable
941962
public func exp10<T: TensorFlowFloatingPoint>(_ x: Tensor<T>) -> Tensor<T> {
942963
pow(10, x)
943964
}

0 commit comments

Comments
 (0)