Skip to content

Commit ad40ff8

Browse files
Shashi456dan12411
authored andcommitted
Adding Squared Hinge Loss (tensorflow#187)
1 parent f55344e commit ad40ff8

File tree

2 files changed

+16
-0
lines changed

2 files changed

+16
-0
lines changed

Sources/TensorFlow/Loss.swift

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,13 @@ public func hingeLoss<Scalar: TensorFlowFloatingPoint>(
4848
return max(Tensor(1) - expected * predicted, Tensor(0)).mean()
4949
}
5050

51+
@differentiable(wrt: predicted)
52+
public func squaredHingeLoss<Scalar: TensorFlowFloatingPoint>(
53+
predicted: Tensor<Scalar>, expected: Tensor<Scalar>
54+
) -> Tensor<Scalar> {
55+
return (max(Tensor(1) - expected * predicted, Tensor(0))).squared().mean()
56+
}
57+
5158
/// Returns the hinge loss between predictions and expectations.
5259
///
5360
/// - Parameters:

Tests/TensorFlowTests/LossTests.swift

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,14 @@ final class LossTests: XCTestCase {
6868
assertElementsEqual(expected: Tensor(expectedLoss), actual: loss)
6969
}
7070

71+
func testSquaredHingeLoss() {
72+
let predicted = Tensor<Float>([1, 2, 3, 4, 5, 6, 7, 8])
73+
let expected = Tensor<Float>([0.5, 1, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0])
74+
let loss = squaredHingeLoss(predicted: predicted, expected: expected)
75+
let expectedLoss: Float = 0.03125
76+
assertElementsEqual(expected: Tensor(expectedLoss), actual: loss)
77+
}
78+
7179
func testCategoricalHingeLoss() {
7280
let predicted = Tensor<Float>([3, 4 ,5])
7381
let expected = Tensor<Float>([0.3, 0.4, 0.3])
@@ -171,6 +179,7 @@ final class LossTests: XCTestCase {
171179
("testMeanSquaredErrorGrad", testMeanSquaredErrorGrad),
172180
("testHingeLoss", testHingeLoss),
173181
("testCategoricalHingeLoss", testCategoricalHingeLoss),
182+
("testSquaredHingeLoss", testSquaredHingeLoss),
174183
("testSoftmaxCrossEntropyWithProbabilitiesLoss",
175184
testSoftmaxCrossEntropyWithProbabilitiesLoss),
176185
("testSoftmaxCrossEntropyWithProbabilitiesGrad",

0 commit comments

Comments
 (0)