Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 65b320c

Browse files
committed
merging master
2 parents ce2a464 + a803823 commit 65b320c

File tree

2 files changed

+23
-3
lines changed

2 files changed

+23
-3
lines changed

Sources/TensorFlow/Loss.swift

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,16 +48,25 @@ public func hingeLoss<Scalar: TensorFlowFloatingPoint>(
4848
return max(Tensor(1) - expected * predicted, Tensor(0)).mean()
4949
}
5050

51-
/// Returns the squared hinge loss between predictions and expectations.
///
/// - Parameters:
///   - predicted: Predicted outputs from a neural network.
///   - expected: Expected values, i.e. targets, that correspond to the correct output.
@differentiable(wrt: predicted)
public func squaredHingeLoss<Scalar: TensorFlowFloatingPoint>(
    predicted: Tensor<Scalar>, expected: Tensor<Scalar>
) -> Tensor<Scalar> {
    // Element-wise hinge term, squared, then averaged over all elements.
    // Parentheses around `max(...)` were redundant; matches the style of `hingeLoss`.
    return max(Tensor(1) - expected * predicted, Tensor(0)).squared().mean()
}
57+
58+
/// Returns the categorical hinge loss between predictions and expectations.
///
/// - Parameters:
///   - predicted: Predicted outputs from a neural network.
///   - expected: Expected values, i.e. targets, that correspond to the correct output.
@differentiable(wrt: predicted)
public func categoricalHingeLoss<Scalar: TensorFlowFloatingPoint>(
    predicted: Tensor<Scalar>, expected: Tensor<Scalar>
) -> Tensor<Scalar> {
    // Score mass on the correct class(es)...
    let positive = (expected * predicted).sum()
    // ...versus the largest score on any incorrect class.
    let negative = ((Tensor(1) - expected) * predicted).max()
    // Standard categorical hinge: max(0, neg - pos + 1).
    return max(Tensor(0), negative - positive + Tensor(1))
}
6271

6372
/// Computes the softmax cross entropy (categorical cross entropy) between logits and labels.

Tests/TensorFlowTests/LossTests.swift

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,15 @@ final class LossTests: XCTestCase {
7676
assertElementsEqual(expected: Tensor(expectedLoss), actual: loss)
7777
}
7878

79+
func testCategoricalHingeLoss() {
    let predicted = Tensor<Float>([3, 4, 5])
    let expected = Tensor<Float>([0.3, 0.4, 0.3])

    let loss = categoricalHingeLoss(predicted: predicted, expected: expected)
    // positive = (0.3*3 + 0.4*4 + 0.3*5) = 4.0
    // negative = max(0.7*3, 0.6*4, 0.7*5) = 3.5
    // loss = max(0, 3.5 - 4.0 + 1.0) = 0.5
    let expectedLoss: Float = 0.5
    assertElementsEqual(expected: Tensor(expectedLoss), actual: loss)
}
87+
7988
func testSoftmaxCrossEntropyWithProbabilitiesLoss() {
8089
let logits = Tensor<Float>(shape: [2, 4], scalars: [1, 2, 3, 4, 5, 6, 7, 8])
8190
let labels = Tensor<Float>(
@@ -168,6 +177,8 @@ final class LossTests: XCTestCase {
168177
static var allTests = [
169178
("testMeanSquaredErrorLoss", testMeanSquaredErrorLoss),
170179
("testMeanSquaredErrorGrad", testMeanSquaredErrorGrad),
180+
("testHingeLoss", testHingeLoss),
181+
("testCategoricalHingeLoss", testCategoricalHingeLoss),
171182
("testSquaredHingeLoss", testSquaredHingeLoss),
172183
("testSoftmaxCrossEntropyWithProbabilitiesLoss",
173184
testSoftmaxCrossEntropyWithProbabilitiesLoss),

0 commit comments

Comments
 (0)