Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 9367aeb

Browse files
committed
Adding Squared Hinge loss and test:
1 parent 3d3fe14 commit 9367aeb

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

Sources/TensorFlow/Loss.swift

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,18 @@ public func hingeLoss<Scalar: TensorFlowFloatingPoint>(
4848
return max(Tensor(1) - expected * predicted, Tensor(0)).mean()
4949
}
5050

51+
/// Returns the squared hinge loss between predictions and expectations.
52+
///
53+
/// - Parameters:
54+
/// - predicted: Predicted outputs from a neural network.
55+
/// - expected: Expected values, i.e. targets, that correspond to the correct output.
56+
@differentiable(wrt: predicted)
57+
public func squaredHingeLoss<Scalar: TensorFlowFloatingPoint>(
58+
predicted: Tensor<Scalar>, expected: Tensor<Scalar>
59+
) -> Tensor<Scalar> {
60+
return (max(Tensor(1) - expected * predicted, Tensor(0))).squared().mean()
61+
}
62+
5163
/// Computes the softmax cross entropy (categorical cross entropy) between logits and labels.
5264
///
5365
/// - Parameters:

Tests/TensorFlowTests/LossTests.swift

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,14 @@ final class LossTests: XCTestCase {
6868
assertElementsEqual(expected: Tensor(expectedLoss), actual: loss)
6969
}
7070

71+
func testSquaredHingeLoss() {
72+
let predicted = Tensor<Float>([1, 2, 3, 4, 5, 6, 7, 8])
73+
let expected = Tensor<Float>([0.5, 1, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0])
74+
let loss = squaredHingeLoss(predicted: predicted, expected: expected)
75+
let expectedLoss: Float = 0.03125
76+
assertElementsEqual(expected: Tensor(expectedLoss), actual: loss)
77+
}
78+
7179
func testSoftmaxCrossEntropyWithProbabilitiesLoss() {
7280
let logits = Tensor<Float>(shape: [2, 4], scalars: [1, 2, 3, 4, 5, 6, 7, 8])
7381
let labels = Tensor<Float>(

0 commit comments

Comments
 (0)