Commit 1a7fb90: merging master
2 parents 05b40de + 20a0923
10 files changed
+824 -31 lines changed
Sources/TensorFlow/Initializers.swift
Lines changed: 3 additions & 4 deletions

```diff
@@ -449,7 +449,7 @@ public extension Tensor where Scalar: BinaryFloatingPoint,
     }
 }

-fileprivate extension Tensor where Scalar: BinaryFloatingPoint {
+fileprivate extension Tensor where Scalar: TensorFlowFloatingPoint {
     private static func glorot(
         fromStandardUniform randomUniform: __shared Tensor<Scalar>,
         shape: __shared TensorShape
@@ -459,7 +459,7 @@ fileprivate extension Tensor where Scalar: BinaryFloatingPoint {
         let fanIn = shape[shape.count - 2] * receptiveField
         let fanOut = shape[shape.count - 1] * receptiveField
         let minusOneToOne = 2 * randomUniform - 1
-        return sqrt(Scalar(6) / Scalar(fanIn + fanOut)) * minusOneToOne
+        return Scalar.sqrt(Scalar(6) / Scalar(fanIn + fanOut)) * minusOneToOne
     }
 }

@@ -483,8 +483,7 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint {
     }
 }

-public extension Tensor where Scalar: BinaryFloatingPoint,
-                               Scalar.RawSignificand: FixedWidthInteger {
+public extension Tensor where Scalar: TensorFlowFloatingPoint {
     /// Performs Glorot uniform initialization for the specified shape, creating a tensor by
     /// randomly sampling scalar values from a uniform distribution between `-limit` and `limit`,
     /// where limit is `sqrt(6 / (fanIn + fanOut))` and `fanIn`/`fanOut` represent the number of
```
Sources/TensorFlow/Layers/Recurrent.swift
Lines changed: 1 addition & 1 deletion

```diff
@@ -88,7 +88,7 @@ public struct SimpleRNNCell<Scalar: TensorFlowFloatingPoint>: RNNCell, VectorPro
     // TODO(TF-507): Revert to `typealias State = Tensor<Scalar>` after
     // SR-10697 is fixed.
     public struct State: Equatable, Differentiable, VectorProtocol, KeyPathIterable {
-        public let value: Tensor<Scalar>
+        public var value: Tensor<Scalar>
         public init(_ value: Tensor<Scalar>) {
             self.value = value
         }
```
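Loosening `value` from `let` to `var` makes `State` mutable in place, which matters for loops that thread recurrent state forward. A sketch, assuming `SimpleRNNCell.State` keeps the single-argument initializer shown above:

```swift
import TensorFlow

var state = SimpleRNNCell<Float>.State(Tensor(zeros: [1, 4]))
// With `value` declared `var`, the state can be overwritten in place
// instead of rebuilding a fresh `State` every time step.
state.value = Tensor(ones: [1, 4])
```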

Sources/TensorFlow/Loss.swift
Lines changed: 74 additions & 6 deletions

```diff
@@ -12,11 +12,35 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-/// Computes the mean squared error between predictions and labels.
+/// Returns the L1 loss between predictions and expectations.
 ///
 /// - Parameters:
 ///   - predicted: Predicted outputs from a neural network.
-///   - labels: Expected values, i.e. targets, that correspond to the correct output.
+///   - expected: Expected values, i.e. targets, that correspond to the correct output.
+@differentiable(wrt: predicted)
+public func l1Loss<Scalar: TensorFlowFloatingPoint>(
+    predicted: Tensor<Scalar>, expected: Tensor<Scalar>
+) -> Tensor<Scalar> {
+    return abs(expected - predicted).sum()
+}
+
+/// Returns the L2 loss between predictions and expectations.
+///
+/// - Parameters:
+///   - predicted: Predicted outputs from a neural network.
+///   - expected: Expected values, i.e. targets, that correspond to the correct output.
+@differentiable(wrt: predicted)
+public func l2Loss<Scalar: TensorFlowFloatingPoint>(
+    predicted: Tensor<Scalar>, expected: Tensor<Scalar>
+) -> Tensor<Scalar> {
+    return (expected - predicted).squared().sum()
+}
+
+/// Returns the mean squared error between predictions and expectations.
+///
+/// - Parameters:
+///   - predicted: Predicted outputs from a neural network.
+///   - expected: Expected values, i.e. targets, that correspond to the correct output.
 @differentiable(wrt: predicted)
 public func meanSquaredError<Scalar: TensorFlowFloatingPoint>(
     predicted: Tensor<Scalar>, expected: Tensor<Scalar>
@@ -41,7 +65,7 @@ public func meanSquaredLogarithmicError<Scalar: TensorFlowFloatingPoint>(
     return (logPredicted - logExpected).squared().mean()
 }

-/// Computes the mean absolute error between predictions and expectations.
+/// Returns the mean absolute error between predictions and expectations.
 ///
 /// - Parameters:
 ///   - predicted: Predicted outputs from a neural network.
@@ -53,6 +77,19 @@ public func meanAbsoluteError<Scalar: TensorFlowFloatingPoint>(
     return abs(expected - predicted).mean()
 }

+/// Returns the mean absolute percentage error between predictions and expectations.
+///
+/// - Parameters:
+///   - predicted: Predicted outputs from a neural network.
+///   - expected: Expected values, i.e. targets, that correspond to the correct output.
+@differentiable(wrt: predicted)
+public func meanAbsolutePercentageError<Scalar: TensorFlowFloatingPoint>(
+    predicted: Tensor<Scalar>, expected: Tensor<Scalar>
+) -> Tensor<Scalar> {
+    let diff = abs((expected - predicted) / abs(expected))
+    return 100 * diff.mean()
+}
+
 /// Returns the hinge loss between predictions and expectations.
 ///
 /// - Parameters:
@@ -65,6 +102,24 @@ public func hingeLoss<Scalar: TensorFlowFloatingPoint>(
     return max(Tensor(1) - expected * predicted, Tensor(0)).mean()
 }

+/// Returns the cosine similarity between predictions and expectations.
+///
+/// - Parameters:
+///   - predicted: Predicted outputs from a neural network.
+///   - expected: Expected values, i.e. targets, that correspond to the correct output.
+@differentiable(wrt: (predicted, expected))
+public func cosineSimilarity<Scalar: TensorFlowFloatingPoint>(
+    predicted: Tensor<Scalar>, expected: Tensor<Scalar>
+) -> Tensor<Scalar> {
+    return -(expected * predicted).sum() /
+        (sqrt(expected.squared().sum()) * sqrt(predicted.squared().sum()))
+}
+
+/// Returns the squared hinge loss between predictions and expectations.
+///
+/// - Parameters:
+///   - predicted: Predicted outputs from a neural network.
+///   - expected: Expected values, i.e. targets, that correspond to the correct output.
 @differentiable(wrt: predicted)
 public func squaredHingeLoss<Scalar: TensorFlowFloatingPoint>(
     predicted: Tensor<Scalar>, expected: Tensor<Scalar>
@@ -118,7 +173,20 @@ public func poissonLoss<Scalar: TensorFlowFloatingPoint>(
     return (predicted - expected * log(predicted)).mean()
 }

-/// Computes the softmax cross entropy (categorical cross entropy) between logits and labels.
+/// Returns the Kullback-Leibler divergence (KL divergence) between expectations and predictions.
+/// Given two distributions `p` and `q`, KL divergence computes `(p * log(p / q)).sum()`.
+///
+/// - Parameters:
+///   - predicted: Predicted outputs from a neural network.
+///   - expected: Expected values, i.e. targets, that correspond to the correct output.
+@differentiable(wrt: predicted)
+public func kullbackLeiblerDivergence<Scalar: TensorFlowFloatingPoint>(
+    predicted: Tensor<Scalar>, expected: Tensor<Scalar>
+) -> Tensor<Scalar> {
+    return (expected * log(expected / predicted)).sum()
+}
+
+/// Returns the softmax cross entropy (categorical cross entropy) between logits and labels.
 ///
 /// - Parameters:
 ///   - logits: One-hot encoded outputs from a neural network.
@@ -139,7 +207,7 @@ func _vjpSoftmaxCrossEntropy<Scalar: TensorFlowFloatingPoint>(
     return (loss.mean(), { v in (v / batchSize) * grad })
 }

-/// Computes the softmax cross entropy (categorical cross entropy) between logits and labels.
+/// Returns the softmax cross entropy (categorical cross entropy) between logits and labels.
 ///
 /// - Parameters:
 ///   - logits: Unscaled log probabilities from a neural network.
@@ -161,7 +229,7 @@ func _vjpSoftmaxCrossEntropy<Scalar: TensorFlowFloatingPoint>(
     return (loss.mean(), { v in v / batchSize * grad })
 }

-/// Computes the sigmoid cross entropy (binary cross entropy) between logits and labels.
+/// Returns the sigmoid cross entropy (binary cross entropy) between logits and labels.
 ///
 /// The reduction is reduced over all elements. If reduced over batch size is intended, please
 /// consider to scale the loss.
```
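A quick sketch of the newly added losses in use; the input values are invented for illustration:

```swift
import TensorFlow

let predicted = Tensor<Float>([0.9, 0.4, 0.8])
let expected = Tensor<Float>([1.0, 0.5, 1.0])

// Sum of absolute differences: 0.1 + 0.1 + 0.2 = 0.4.
let l1 = l1Loss(predicted: predicted, expected: expected)
// Sum of squared differences.
let l2 = l2Loss(predicted: predicted, expected: expected)
// Mean of |(e - p) / e|, scaled by 100; about 16.7 here.
let mape = meanAbsolutePercentageError(predicted: predicted, expected: expected)
// (e * log(e / p)).sum(); only meaningful when both inputs are distributions.
let kl = kullbackLeiblerDivergence(predicted: predicted, expected: expected)
```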

Sources/TensorFlow/Operators/Math.swift
Lines changed: 143 additions & 13 deletions

```diff
@@ -15,18 +15,148 @@
 infix operator .>: ComparisonPrecedence
 infix operator .==: ComparisonPrecedence

-// `pow` is defined in Darwin/Glibc on `Float` and `Double`, but there doesn't exist a generic
-// version for `FloatingPoint`.
-// This is a manual definition.
-@inlinable
-func pow<T: BinaryFloatingPoint>(_ x: T, _ y: T) -> T {
-    T(pow(Double(x), Double(y)))
-}
-
 // TODO:
 // - Consider explicit broadcasting for elementwise binary ops when
 //   scalarization and rank getter are implemented.

+//===------------------------------------------------------------------------------------------===//
+// Generic elementary functions
+//===------------------------------------------------------------------------------------------===//
+
+extension Tensor: ElementaryFunctions where Scalar: TensorFlowFloatingPoint {
+    /// The square root of `x`.
+    ///
+    /// For real types, if `x` is negative the result is `.nan`. For complex
+    /// types there is a branch cut on the negative real axis.
+    public static func sqrt(_ x: Self) -> Self {
+        TensorFlow.sqrt(x)
+    }
+
+    /// The cosine of `x`, interpreted as an angle in radians.
+    public static func cos(_ x: Self) -> Self {
+        TensorFlow.cos(x)
+    }
+
+    /// The sine of `x`, interpreted as an angle in radians.
+    public static func sin(_ x: Self) -> Self {
+        TensorFlow.sin(x)
+    }
+
+    /// The tangent of `x`, interpreted as an angle in radians.
+    public static func tan(_ x: Self) -> Self {
+        TensorFlow.tan(x)
+    }
+
+    /// The inverse cosine of `x` in radians.
+    public static func acos(_ x: Self) -> Self {
+        TensorFlow.acos(x)
+    }
+
+    /// The inverse sine of `x` in radians.
+    public static func asin(_ x: Self) -> Self {
+        TensorFlow.asin(x)
+    }
+
+    /// The inverse tangent of `x` in radians.
+    public static func atan(_ x: Self) -> Self {
+        TensorFlow.atan(x)
+    }
+
+    /// The hyperbolic cosine of `x`.
+    public static func cosh(_ x: Self) -> Self {
+        TensorFlow.cosh(x)
+    }
+
+    /// The hyperbolic sine of `x`.
+    public static func sinh(_ x: Self) -> Self {
+        TensorFlow.sinh(x)
+    }
+
+    /// The hyperbolic tangent of `x`.
+    public static func tanh(_ x: Self) -> Self {
+        TensorFlow.tanh(x)
+    }
+
+    /// The inverse hyperbolic cosine of `x`.
+    public static func acosh(_ x: Self) -> Self {
+        TensorFlow.acosh(x)
+    }
+
+    /// The inverse hyperbolic sine of `x`.
+    public static func asinh(_ x: Self) -> Self {
+        TensorFlow.asinh(x)
+    }
+
+    /// The inverse hyperbolic tangent of `x`.
+    public static func atanh(_ x: Self) -> Self {
+        TensorFlow.atanh(x)
+    }
+
+    /// The exponential function applied to `x`, or `e**x`.
+    public static func exp(_ x: Self) -> Self {
+        TensorFlow.exp(x)
+    }
+
+    /// Two raised to the power `x`.
+    public static func exp2(_ x: Self) -> Self {
+        TensorFlow.exp2(x)
+    }
+
+    /// Ten raised to the power `x`.
+    public static func exp10(_ x: Self) -> Self {
+        TensorFlow.exp10(x)
+    }
+
+    /// `exp(x) - 1` evaluated so as to preserve accuracy close to zero.
+    public static func expm1(_ x: Self) -> Self {
+        TensorFlow.expm1(x)
+    }
+
+    /// The natural logarithm of `x`.
+    public static func log(_ x: Self) -> Self {
+        TensorFlow.log(x)
+    }
+
+    /// The base-two logarithm of `x`.
+    public static func log2(_ x: Self) -> Self {
+        TensorFlow.log2(x)
+    }
+
+    /// The base-ten logarithm of `x`.
+    public static func log10(_ x: Self) -> Self {
+        TensorFlow.log10(x)
+    }
+
+    /// `log(1 + x)` evaluated so as to preserve accuracy close to zero.
+    public static func log1p(_ x: Self) -> Self {
+        TensorFlow.log1p(x)
+    }
+
+    /// `exp(y log(x))` computed without loss of intermediate precision.
+    ///
+    /// For real types, if `x` is negative the result is NaN, even if `y` has
+    /// an integral value. For complex types, there is a branch cut on the
+    /// negative real axis.
+    public static func pow(_ x: Self, _ y: Self) -> Self {
+        TensorFlow.pow(x, y)
+    }
+
+    /// `x` raised to the `n`th power.
+    ///
+    /// The product of `n` copies of `x`.
+    public static func pow(_ x: Self, _ n: Int) -> Self {
+        TensorFlow.pow(x, n)
+    }
+
+    /// The `n`th root of `x`.
+    ///
+    /// For real types, if `x` is negative and `n` is even, the result is NaN.
+    /// For complex types, there is a branch cut along the negative real axis.
+    public static func root(_ x: Self, _ n: Int) -> Self {
+        TensorFlow.root(x, n)
+    }
+}
+
 //===------------------------------------------------------------------------------------------===//
 // Vector Space
 //===------------------------------------------------------------------------------------------===//
@@ -876,7 +1006,7 @@ public func pow<T: TensorFlowFloatingPoint>(_ x: Tensor<T>, _ n: Int) -> Tensor<
 @inlinable
 // @differentiable
 public func root<T: TensorFlowFloatingPoint>(_ x: Tensor<T>, _ n: Int) -> Tensor<T> {
-    pow(x, Tensor(T(1) / T(n)))
+    sign(x) * pow(abs(x), Tensor(T(1) / T(n)))
 }

 /// Computes the element-wise maximum of two tensors.
@@ -1580,7 +1710,7 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint {
     @inlinable
     @differentiable(wrt: self)
     func standardDeviation(squeezingAxes axes: Tensor<Int32>) -> Tensor {
-        sqrt(variance(squeezingAxes: axes))
+        TensorFlow.sqrt(variance(squeezingAxes: axes))
     }

     /// Returns the standard deviation of the elements along the specified axes. The reduced
@@ -1591,7 +1721,7 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint {
     @inlinable
     @differentiable(wrt: self)
     func standardDeviation(squeezingAxes axes: [Int]) -> Tensor {
-        sqrt(variance(squeezingAxes: axes))
+        TensorFlow.sqrt(variance(squeezingAxes: axes))
     }

     /// Returns the standard deviation of the elements along the specified axes. The reduced
@@ -1625,7 +1755,7 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint {
     @inlinable
     @differentiable(wrt: self)
     func standardDeviation(alongAxes axes: Tensor<Int32>) -> Tensor {
-        sqrt(variance(alongAxes: axes))
+        TensorFlow.sqrt(variance(alongAxes: axes))
     }

     /// Returns the standard deviation of the elements along the specified axes. The reduced
@@ -1649,7 +1779,7 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint {
     @inlinable
     @differentiable(wrt: self)
     func standardDeviation(alongAxes axes: Int...) -> Tensor {
-        sqrt(variance(alongAxes: axes))
+        TensorFlow.sqrt(variance(alongAxes: axes))
    }
 }
```
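The `ElementaryFunctions` conformance is what enables `Scalar.sqrt` in Initializers.swift above, and it lets the same generic code run on scalars and tensors alike. A sketch, assuming the toolchain's standard-library conformance of `Float` to `ElementaryFunctions` (the `softplusLike` name is invented):

```swift
import TensorFlow

// Generic over any ElementaryFunctions conformer: Float, Double, or Tensor.
func softplusLike<T: ElementaryFunctions>(_ x: T) -> T {
    // log(1 + exp(x)), written with only `ElementaryFunctions` requirements.
    T.log1p(T.exp(x))
}

let scalar: Float = softplusLike(0.5)
let tensor: Tensor<Float> = softplusLike(Tensor([0.5, -1.0]))
```

Separately, the `sign(x) * pow(abs(x), 1/n)` rewrite of the free `root` function means odd roots of negative values now come out signed: the cube root of -8 evaluates to -2 rather than NaN.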

Sources/TensorFlow/Operators/NN.swift
Lines changed: 1 addition & 1 deletion

```diff
@@ -61,7 +61,7 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint {
         let norm = diff * inv

         let dNorm = v * scale
-        let dVariance = -(dNorm * diff).sum(alongAxes: axis) / 2 * pow(inv, -3)
+        let dVariance = -(dNorm * diff).sum(alongAxes: axis) / 2 * TensorFlow.pow(inv, -3)
         // Note: `dMean` is split into two lines to avoid the "compiler is unable to type-check
         // this expression in reasonable time" error.
         var dMean = (-dNorm * inv).sum(alongAxes: axis)
```
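This one-line change is a disambiguation forced by the new conformance: inside a `Tensor` extension, an unqualified `pow` can now resolve to the static `ElementaryFunctions` member instead of the free function, so the call is qualified with the module name. A hypothetical illustration of the pattern (the helper is not from the commit):

```swift
import TensorFlow

extension Tensor where Scalar: TensorFlowFloatingPoint {
    // Qualifying with `TensorFlow.` picks the free function unambiguously
    // now that `Tensor` also exposes a static `pow` member.
    func reciprocalCubed() -> Tensor {
        TensorFlow.pow(self, -3)
    }
}
```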

Sources/TensorFlow/Optimizer.swift
Lines changed: 1 addition & 1 deletion

```diff
@@ -291,7 +291,7 @@ public class RiemannSGD<Model: Layer, Scalar: FloatingPoint>: Optimizer

     public func update(_ model: inout Model.AllDifferentiableVariables,
                        along direction: Model.TangentVector) {
-        model = model.moved(along: learningRate * (.zero - direction))
+        model.move(along: learningRate * (.zero - direction))
     }
 }
```
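`move(along:)` is the mutating counterpart of the functional `moved(along:)`, so the model is now updated in place. A self-contained sketch of the same idea on a toy `Differentiable` type (names invented):

```swift
import TensorFlow

struct Point: Differentiable {
    var x: Float
    var y: Float
}

var p = Point(x: 1, y: 2)
let direction = Point.TangentVector(x: 0.5, y: -0.5)
// In-place update; the older style was `p = p.moved(along: direction)`.
p.move(along: direction)
```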

Tests/TensorFlowTests/Helpers.swift
Lines changed: 3 additions & 1 deletion

```diff
@@ -21,7 +21,9 @@ internal func assertEqual<T: TensorFlowFloatingPoint>(
 ) {
     for (x, y) in zip(x, y) {
         if x.isNaN || y.isNaN {
-            XCTAssertTrue(x.isNaN && y.isNaN, message, file: file, line: line)
+            XCTAssertTrue(x.isNaN && y.isNaN,
+                          "\(x) is not equal to \(y) - \(message)",
+                          file: file, line: line)
             continue
         }
         XCTAssertEqual(x, y, accuracy: accuracy, message, file: file, line: line)
```
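With the richer message, a NaN mismatch now reports the offending values instead of only the caller's message. Assuming the helper keeps the signature implied above, a failing call might read:

```swift
// Hypothetical test snippet: the assertion failure now includes the values,
// e.g. "nan is not equal to 1.0 - gradients should match".
assertEqual([Float.nan], [1.0], accuracy: 1e-6, "gradients should match")
```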
