
Commit 28e7ad2: Merge branch 'master' into lc
2 parents: c2d7c6c + d42e80f

6 files changed: +124 additions, -9 deletions

Sources/TensorFlow/Layers/Recurrent.swift (1 addition, 1 deletion)

```diff
@@ -88,7 +88,7 @@ public struct SimpleRNNCell<Scalar: TensorFlowFloatingPoint>: RNNCell, VectorPro
     // TODO(TF-507): Revert to `typealias State = Tensor<Scalar>` after
     // SR-10697 is fixed.
     public struct State: Equatable, Differentiable, VectorProtocol, KeyPathIterable {
-        public let value: Tensor<Scalar>
+        public var value: Tensor<Scalar>
         public init(_ value: Tensor<Scalar>) {
             self.value = value
         }
```
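The `let` to `var` change matters because a `let` stored property can only be assigned during initialization, so the wrapper's value could never be updated in place. A minimal sketch of the compile-time difference, using a hypothetical standalone struct rather than the library type:

```swift
// Hypothetical State-like struct, not the library's SimpleRNNCell.State.
struct State {
    var value: Float  // with `let`, the mutation below would not compile:
                      // "cannot assign to property: 'value' is a 'let' constant"
    init(_ value: Float) { self.value = value }
}

var state = State(1.0)
state.value = 2.0  // legal only because `value` is `var`
```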

Sources/TensorFlow/Loss.swift (72 additions, 5 deletions)

```diff
@@ -12,7 +12,31 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-/// Computes the mean squared error between predictions and labels.
+/// Returns the L1 loss between predictions and labels.
+///
+/// - Parameters:
+///   - predicted: Predicted outputs from a neural network.
+///   - expected: Expected values, i.e. targets, that correspond to the correct output.
+@differentiable(wrt: predicted)
+public func l1Loss<Scalar: TensorFlowFloatingPoint>(
+    predicted: Tensor<Scalar>, expected: Tensor<Scalar>
+) -> Tensor<Scalar> {
+    return abs(expected - predicted).sum()
+}
+
+/// Returns the L2 loss between predictions and labels.
+///
+/// - Parameters:
+///   - predicted: Predicted outputs from a neural network.
+///   - expected: Expected values, i.e. targets, that correspond to the correct output.
+@differentiable(wrt: predicted)
+public func l2Loss<Scalar: TensorFlowFloatingPoint>(
+    predicted: Tensor<Scalar>, expected: Tensor<Scalar>
+) -> Tensor<Scalar> {
+    return (expected - predicted).squared().sum()
+}
+
+/// Returns the mean squared error between predictions and labels.
 ///
 /// - Parameters:
 ///   - predicted: Predicted outputs from a neural network.
@@ -41,7 +65,7 @@ public func meanSquaredLogarithmicError<Scalar: TensorFlowFloatingPoint>(
     return (logPredicted - logExpected).squared().mean()
 }
 
-/// Computes the mean absolute error between predictions and expectations.
+/// Returns the mean absolute error between predictions and expectations.
 ///
 /// - Parameters:
 ///   - predicted: Predicted outputs from a neural network.
@@ -53,6 +77,19 @@ public func meanAbsoluteError<Scalar: TensorFlowFloatingPoint>(
     return abs(expected - predicted).mean()
 }
 
+/// Returns the mean absolute percentage error between predictions and expectations.
+///
+/// - Parameters:
+///   - predicted: Predicted outputs from a neural network.
+///   - expected: Expected values, i.e. targets, that correspond to the correct output.
+@differentiable(wrt: predicted)
+public func meanAbsolutePercentageError<Scalar: TensorFlowFloatingPoint>(
+    predicted: Tensor<Scalar>, expected: Tensor<Scalar>
+) -> Tensor<Scalar> {
+    let diff = abs((expected - predicted) / abs(expected))
+    return 100 * diff.mean()
+}
+
 /// Returns the hinge loss between predictions and expectations.
 ///
 /// - Parameters:
@@ -65,6 +102,24 @@ public func hingeLoss<Scalar: TensorFlowFloatingPoint>(
     return max(Tensor(1) - expected * predicted, Tensor(0)).mean()
 }
 
+/// Returns the cosine similarity between predictions and expectations.
+///
+/// - Parameters:
+///   - predicted: Predicted outputs from a neural network.
+///   - expected: Expected values, i.e. targets, that correspond to the correct output.
+@differentiable(wrt: (predicted, expected))
+public func cosineSimilarity<Scalar: TensorFlowFloatingPoint>(
+    predicted: Tensor<Scalar>, expected: Tensor<Scalar>
+) -> Tensor<Scalar> {
+    return -(expected * predicted).sum() /
+        (sqrt(expected.squared().sum()) * sqrt(predicted.squared().sum()))
+}
+
+/// Returns the squared hinge loss between predictions and expectations.
+///
+/// - Parameters:
+///   - predicted: Predicted outputs from a neural network.
+///   - expected: Expected values, i.e. targets, that correspond to the correct output.
 @differentiable(wrt: predicted)
 public func squaredHingeLoss<Scalar: TensorFlowFloatingPoint>(
     predicted: Tensor<Scalar>, expected: Tensor<Scalar>
@@ -119,7 +174,19 @@ public func poissonLoss<Scalar: TensorFlowFloatingPoint>(
     return (predicted - expected * log(predicted)).mean()
 }
 
-/// Computes the softmax cross entropy (categorical cross entropy) between logits and labels.
+/// Returns the Kullback-Leibler divergence between predictions and expectations.
+///
+/// - Parameters:
+///   - predicted: Predicted outputs from a neural network.
+///   - expected: Expected values, i.e. targets, that correspond to the correct output.
+@differentiable(wrt: predicted)
+public func kullbackLeiblerDivergence<Scalar: TensorFlowFloatingPoint>(
+    predicted: Tensor<Scalar>, expected: Tensor<Scalar>
+) -> Tensor<Scalar> {
+    return (expected * log(expected / predicted)).sum()
+}
+
+/// Returns the softmax cross entropy (categorical cross entropy) between logits and labels.
 ///
 /// - Parameters:
 ///   - logits: One-hot encoded outputs from a neural network.
@@ -140,7 +207,7 @@ func _vjpSoftmaxCrossEntropy<Scalar: TensorFlowFloatingPoint>(
     return (loss.mean(), { v in (v / batchSize) * grad })
 }
 
-/// Computes the softmax cross entropy (categorical cross entropy) between logits and labels.
+/// Returns the softmax cross entropy (categorical cross entropy) between logits and labels.
 ///
 /// - Parameters:
 ///   - logits: Unscaled log probabilities from a neural network.
@@ -162,7 +229,7 @@ func _vjpSoftmaxCrossEntropy<Scalar: TensorFlowFloatingPoint>(
     return (loss.mean(), { v in v / batchSize * grad })
 }
 
-/// Computes the sigmoid cross entropy (binary cross entropy) between logits and labels.
+/// Returns the sigmoid cross entropy (binary cross entropy) between logits and labels.
 ///
 /// The reduction is applied over all elements. If reduction over the batch size is intended,
 /// please consider scaling the loss.
```
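A usage sketch of the new losses may help; the values below follow directly from the formulas above and match the tests added in this commit. Note that `l1Loss` and `l2Loss` reduce with `sum()` rather than `mean()`, and that `cosineSimilarity` returns the negated similarity so it can be minimized as a loss:

```swift
import TensorFlow

// L1: sum of absolute differences = 0.9 + 1.8 + 2.7 + 3.6 = 9.0
print(l1Loss(predicted: Tensor<Float>([1, 2, 3, 4]),
             expected: Tensor<Float>([0.1, 0.2, 0.3, 0.4])))   // 9.0

// L2: sum of squared differences = 4 * 0.25 = 1.0
print(l2Loss(predicted: Tensor<Float>([1, 2, 3, 4]),
             expected: Tensor<Float>([0.5, 1.5, 2.5, 3.5])))   // 1.0

// Proportional vectors have cosine similarity 1, returned negated.
print(cosineSimilarity(predicted: Tensor<Float>([1, 2]),
                       expected: Tensor<Float>([0.5, 1])))     // -1.0
```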

Sources/TensorFlow/Operators/Math.swift (1 addition, 1 deletion)

```diff
@@ -876,7 +876,7 @@ public func pow<T: TensorFlowFloatingPoint>(_ x: Tensor<T>, _ n: Int) -> Tensor<
 @inlinable
 // @differentiable
 public func root<T: TensorFlowFloatingPoint>(_ x: Tensor<T>, _ n: Int) -> Tensor<T> {
-    pow(x, Tensor(T(1) / T(n)))
+    sign(x) * pow(abs(x), Tensor(T(1) / T(n)))
 }
 
 /// Computes the element-wise maximum of two tensors.
```
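This fixes NaNs for negative bases: a fractional power of a negative number is undefined over the reals, so `pow` returns NaN there. Computing the power on `abs(x)` and restoring the sign makes odd roots of negative values behave as expected:

```swift
import TensorFlow

let x = Tensor<Float>([-8, 27])
print(root(x, 3))  // [-2.0, 3.0]; the old formula produced [nan, 3.0]
```

One caveat implied by the formula: for even `n` and negative `x`, the new version returns `-abs(x)^(1/n)` instead of NaN, a convention callers should be aware of.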

Sources/TensorFlow/Optimizer.swift (1 addition, 1 deletion)

```diff
@@ -291,7 +291,7 @@ public class RiemannSGD<Model: Layer, Scalar: FloatingPoint>: Optimizer
 
     public func update(_ model: inout Model.AllDifferentiableVariables,
                        along direction: Model.TangentVector) {
-        model = model.moved(along: learningRate * (.zero - direction))
+        model.move(along: learningRate * (.zero - direction))
     }
 }
```
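`move(along:)` is the mutating counterpart of `moved(along:)` on `Differentiable`, so the update now mutates the model in place instead of building a moved copy and reassigning it. A small sketch of the equivalence, with a hypothetical one-parameter model:

```swift
import TensorFlow

struct Params: Differentiable {
    var w: Tensor<Float>
}

var params = Params(w: Tensor<Float>([1, 2]))
let step = Params.TangentVector(w: Tensor<Float>([0.1, 0.1]))
params.move(along: step)  // in place; same result as params = params.moved(along: step)
print(params.w)           // [1.1, 2.1]
```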

Tests/TensorFlowTests/Helpers.swift (3 additions, 1 deletion)

```diff
@@ -21,7 +21,9 @@ internal func assertEqual<T: TensorFlowFloatingPoint>(
 ) {
     for (x, y) in zip(x, y) {
         if x.isNaN || y.isNaN {
-            XCTAssertTrue(x.isNaN && y.isNaN, message, file: file, line: line)
+            XCTAssertTrue(x.isNaN && y.isNaN,
+                          "\(x) is not equal to \(y) - \(message)",
+                          file: file, line: line)
             continue
         }
         XCTAssertEqual(x, y, accuracy: accuracy, message, file: file, line: line)
```
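The reworded assertion makes NaN mismatches self-describing: the failure now reports both scalars alongside the caller-supplied message. An illustrative print of the interpolated message (values made up):

```swift
let x = Float.nan
let y: Float = 1.0
let message = "gradient check"
print("\(x) is not equal to \(y) - \(message)")
// nan is not equal to 1.0 - gradient check
```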

Tests/TensorFlowTests/LossTests.swift (46 additions, 0 deletions)

```diff
@@ -16,6 +16,22 @@ import XCTest
 @testable import TensorFlow
 
 final class LossTests: XCTestCase {
+    func testL1Loss() {
+        let predicted = Tensor<Float>([1, 2, 3, 4])
+        let expected = Tensor<Float>([0.1, 0.2, 0.3, 0.4])
+        let loss = l1Loss(predicted: predicted, expected: expected)
+        let expectedLoss: Float = 9.0
+        assertElementsEqual(expected: Tensor(expectedLoss), actual: loss)
+    }
+
+    func testL2Loss() {
+        let predicted = Tensor<Float>([1, 2, 3, 4])
+        let expected = Tensor<Float>([0.5, 1.5, 2.5, 3.5])
+        let loss = l2Loss(predicted: predicted, expected: expected)
+        let expectedLoss: Float = 1.0
+        assertElementsEqual(expected: Tensor(expectedLoss), actual: loss)
+    }
+
     func testMeanSquaredErrorLoss() {
         let predicted = Tensor<Float>(shape: [2, 4], scalars: [1, 2, 3, 4, 5, 6, 7, 8])
         let expected = Tensor<Float>(
@@ -49,6 +65,15 @@ final class LossTests: XCTestCase {
         assertElementsEqual(expected: Tensor(expectedLoss), actual: loss)
     }
 
+    func testMeanAbsolutePercentageError() {
+        let predicted = Tensor<Float>([1, 2, 3, 4, 5, 6, 7, 8])
+        let expected = Tensor<Float>([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8])
+
+        let loss = meanAbsolutePercentageError(predicted: predicted, expected: expected)
+        let expectedLoss: Float = 900.0
+        assertElementsEqual(expected: Tensor(expectedLoss), actual: loss)
+    }
+
     func testMeanSquaredErrorGrad() {
         let predicted = Tensor<Float>(shape: [2, 4], scalars: [1, 2, 3, 4, 5, 6, 7, 8])
         let expected = Tensor<Float>(
@@ -78,6 +103,14 @@ final class LossTests: XCTestCase {
         let expectedLoss: Float = 0.225
         assertElementsEqual(expected: Tensor(expectedLoss), actual: loss)
     }
+
+    func testCosineSimilarityLoss() {
+        let predicted = Tensor<Float>([1, 2, 3, 4, 5, 6, 7, 8])
+        let expected = Tensor<Float>([0.5, 1, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0])
+        let loss = cosineSimilarity(predicted: predicted, expected: expected)
+        let expectedLoss: Float = -1.0
+        assertElementsEqual(expected: Tensor(expectedLoss), actual: loss)
+    }
 
     func testSquaredHingeLoss() {
         let predicted = Tensor<Float>([1, 2, 3, 4, 5, 6, 7, 8])
@@ -112,6 +145,14 @@ final class LossTests: XCTestCase {
         assertElementsEqual(expected: Tensor(expectedLoss), actual: loss)
     }
 
+    func testKullbackLeiblerDivergence() {
+        let predicted = Tensor<Float>([0.2, 0.3, 0.4])
+        let expected = Tensor<Float>([1.0, 4.0, 3.0])
+        let loss = kullbackLeiblerDivergence(predicted: predicted, expected: expected)
+        let expectedLoss: Float = 18.015217
+        assertElementsEqual(expected: Tensor(expectedLoss), actual: loss)
+    }
+
     func testSoftmaxCrossEntropyWithProbabilitiesLoss() {
         let logits = Tensor<Float>(shape: [2, 4], scalars: [1, 2, 3, 4, 5, 6, 7, 8])
         let labels = Tensor<Float>(
@@ -202,12 +243,17 @@ final class LossTests: XCTestCase {
     }
 
     static var allTests = [
+        ("testL1Loss", testL1Loss),
+        ("testL2Loss", testL2Loss),
         ("testMeanSquaredErrorLoss", testMeanSquaredErrorLoss),
         ("testMeanSquaredErrorGrad", testMeanSquaredErrorGrad),
         ("testMeanSquaredLogarithmicError", testMeanSquaredLogarithmicError),
         ("testMeanAbsoluteError", testMeanAbsoluteError),
+        ("testMeanAbsolutePercentageError", testMeanAbsolutePercentageError),
         ("testHingeLoss", testHingeLoss),
+        ("testKullbackLeiblerDivergence", testKullbackLeiblerDivergence),
         ("testCategoricalHingeLoss", testCategoricalHingeLoss),
+        ("testCosineSimilarityLoss", testCosineSimilarityLoss),
         ("testSquaredHingeLoss", testSquaredHingeLoss),
         ("testPoissonLoss", testPoissonLoss),
         ("testLogcoshLoss", testLogcoshLoss),
```
