This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit a803823

Shashi456 authored and rxwei committed
Add Categorical Hinge Loss (#188)
1 parent bf38aa8 commit a803823

File tree

2 files changed (+24 −0 lines changed):
Sources/TensorFlow/Loss.swift
Tests/TensorFlowTests/LossTests.swift


Sources/TensorFlow/Loss.swift

Lines changed: 14 additions & 0 deletions
@@ -48,6 +48,20 @@ public func hingeLoss<Scalar: TensorFlowFloatingPoint>(
     return max(Tensor(1) - expected * predicted, Tensor(0)).mean()
 }
 
+/// Returns the categorical hinge loss between predictions and expectations.
+///
+/// - Parameters:
+///   - predicted: Predicted outputs from a neural network.
+///   - expected: Expected values, i.e. targets, that correspond to the correct output.
+@differentiable(wrt: predicted)
+public func categoricalHingeLoss<Scalar: TensorFlowFloatingPoint>(
+    predicted: Tensor<Scalar>, expected: Tensor<Scalar>
+) -> Tensor<Scalar> {
+    let positive = (expected * predicted).sum()
+    let negative = ((Tensor(1) - expected) * predicted).max()
+    return max(Tensor(0), negative - positive + Tensor(1))
+}
+
 /// Computes the softmax cross entropy (categorical cross entropy) between logits and labels.
 ///
 /// - Parameters:
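
For anyone trying out the new API, here is a minimal usage sketch (assuming the module is imported as import TensorFlow, as in this repository; the tensor values are the same illustrative ones used in the test below):

import TensorFlow

// Per-class scores from a model and a (soft) target distribution.
let predicted = Tensor<Float>([3, 4, 5])
let expected = Tensor<Float>([0.3, 0.4, 0.3])

// positive = sum(expected * predicted): score mass assigned to the target classes.
// negative = max((1 - expected) * predicted): strongest competing score.
// loss     = max(0, negative - positive + 1).
let loss = categoricalHingeLoss(predicted: predicted, expected: expected)
print(loss)  // 0.5 for these values.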

Tests/TensorFlowTests/LossTests.swift

Lines changed: 10 additions & 0 deletions
@@ -68,6 +68,15 @@ final class LossTests: XCTestCase {
         assertElementsEqual(expected: Tensor(expectedLoss), actual: loss)
     }
 
+    func testCategoricalHingeLoss() {
+        let predicted = Tensor<Float>([3, 4, 5])
+        let expected = Tensor<Float>([0.3, 0.4, 0.3])
+
+        let loss = categoricalHingeLoss(predicted: predicted, expected: expected)
+        let expectedLoss: Float = 0.5
+        assertElementsEqual(expected: Tensor(expectedLoss), actual: loss)
+    }
+
     func testSoftmaxCrossEntropyWithProbabilitiesLoss() {
         let logits = Tensor<Float>(shape: [2, 4], scalars: [1, 2, 3, 4, 5, 6, 7, 8])
         let labels = Tensor<Float>(
@@ -161,6 +170,7 @@ final class LossTests: XCTestCase {
         ("testMeanSquaredErrorLoss", testMeanSquaredErrorLoss),
         ("testMeanSquaredErrorGrad", testMeanSquaredErrorGrad),
         ("testHingeLoss", testHingeLoss),
+        ("testCategoricalHingeLoss", testCategoricalHingeLoss),
         ("testSoftmaxCrossEntropyWithProbabilitiesLoss",
          testSoftmaxCrossEntropyWithProbabilitiesLoss),
         ("testSoftmaxCrossEntropyWithProbabilitiesGrad",

0 commit comments

Comments
 (0)