Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 3d3fe14

Browse files
shingtrxwei
authored andcommitted
Add hinge loss function (#185)
Added hinge loss function. Ref: #127
1 parent 7a104ca commit 3d3fe14

File tree

2 files changed

+23
-0
lines changed

2 files changed

+23
-0
lines changed

Sources/TensorFlow/Loss.swift

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,18 @@ public func meanAbsoluteError<Scalar: TensorFlowFloatingPoint>(
3636
return abs(expected - predicted).mean()
3737
}
3838

39+
/// Returns the hinge loss between predictions and expectations.
40+
///
41+
/// - Parameters:
42+
/// - predicted: Predicted outputs from a neural network.
43+
/// - expected: Expected values, i.e. targets, that correspond to the correct output.
44+
@differentiable(wrt: predicted)
45+
public func hingeLoss<Scalar: TensorFlowFloatingPoint>(
46+
predicted: Tensor<Scalar>, expected: Tensor<Scalar>
47+
) -> Tensor<Scalar> {
48+
return max(Tensor(1) - expected * predicted, Tensor(0)).mean()
49+
}
50+
3951
/// Computes the softmax cross entropy (categorical cross entropy) between logits and labels.
4052
///
4153
/// - Parameters:

Tests/TensorFlowTests/LossTests.swift

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,17 @@ final class LossTests: XCTestCase {
5757
assertElementsEqual(expected: expectedGradients, actual: gradients)
5858
}
5959

60+
func testHingeLoss() {
61+
let predicted = Tensor<Float>(shape: [2, 4], scalars: [1, 2, 3, 4, 5, 6, 7, 8])
62+
let expected = Tensor<Float>(
63+
shape: [2, 4],
64+
scalars: [0.1, 0.2, 0.3, 0.4, 0.4, 0.3, 0.2, 0.1])
65+
66+
let loss = hingeLoss(predicted: predicted, expected: expected)
67+
let expectedLoss: Float = 0.225
68+
assertElementsEqual(expected: Tensor(expectedLoss), actual: loss)
69+
}
70+
6071
func testSoftmaxCrossEntropyWithProbabilitiesLoss() {
6172
let logits = Tensor<Float>(shape: [2, 4], scalars: [1, 2, 3, 4, 5, 6, 7, 8])
6273
let labels = Tensor<Float>(

0 commit comments

Comments
 (0)