Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 83b9196

Browse files
committed
Adding logcosh loss
1 parent fb3135c commit 83b9196

File tree

2 files changed

+44
-0
lines changed

2 files changed

+44
-0
lines changed

Sources/TensorFlow/Loss.swift

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,41 @@ public func categoricalHingeLoss<Scalar: TensorFlowFloatingPoint>(
8686
return max(Tensor(0), negative - positive + Tensor(1))
8787
}
8888

89+
/// Helper function for Logcosh
90+
@differentiable(wrt: x)
91+
internal func logcosh<Scalar: TensorFlowFloatingPoint>(
92+
x: Tensor<Scalar>
93+
) -> Tensor<Scalar> {
94+
let y = Tensor<Scalar>([2])
95+
return x + softplus(Tensor(-2) * x) - log(y)
96+
}
97+
98+
/// Returns the Logcosh loss between predictions and expectations.
99+
///
100+
/// - Parameters:
101+
/// - predicted: Predicted outputs from a neural network.
102+
/// - expected: Expected values, i.e. targets, that correspond to the correct output.
103+
@differentiable(wrt: predicted)
104+
public func logcoshLoss<Scalar: TensorFlowFloatingPoint>(
105+
predicted: Tensor<Scalar>, expected: Tensor<Scalar>
106+
) -> Tensor<Scalar> {
107+
return (logcosh(x: predicted - expected)).mean()
108+
}
109+
110+
111+
/// Returns the Poisson loss between predictions and expectations.
112+
///
113+
/// - Parameters:
114+
/// - predicted: Predicted outputs from a neural network.
115+
/// - expected: Expected values, i.e. targets, that correspond to the correct output.
116+
@differentiable(wrt: predicted)
117+
public func poissonLoss<Scalar: TensorFlowFloatingPoint>(
118+
predicted: Tensor<Scalar>, expected: Tensor<Scalar>
119+
) -> Tensor<Scalar> {
120+
return (predicted - expected * log(predicted)).mean()
121+
}
122+
123+
89124
/// Returns the Poisson loss between predictions and expectations.
90125
///
91126
/// - Parameters:

Tests/TensorFlowTests/LossTests.swift

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,14 @@ final class LossTests: XCTestCase {
9696
assertElementsEqual(expected: Tensor(expectedLoss), actual: loss)
9797
}
9898

99+
func testLogcoshLoss() {
100+
let predicted = Tensor<Float>([0.2, 0.3, 0.4])
101+
let expected = Tensor<Float>([1.0, 4.0, 3.0])
102+
let loss = logcoshLoss(predicted: predicted, expected: expected)
103+
let expectedLoss: Float = 1.7368573
104+
assertElementsEqual(expected: Tensor(expectedLoss), actual: loss)
105+
}
106+
99107
func testPoissonLoss() {
100108
let predicted = Tensor<Float>([0.1, 0.2, 0.3])
101109
let expected = Tensor<Float>([1, 2, 3])
@@ -202,6 +210,7 @@ final class LossTests: XCTestCase {
202210
("testCategoricalHingeLoss", testCategoricalHingeLoss),
203211
("testSquaredHingeLoss", testSquaredHingeLoss),
204212
("testPoissonLoss",testPoissonLoss),
213+
("testLogcoshLoss", testLogcoshLoss),
205214
("testSoftmaxCrossEntropyWithProbabilitiesLoss",
206215
testSoftmaxCrossEntropyWithProbabilitiesLoss),
207216
("testSoftmaxCrossEntropyWithProbabilitiesGrad",

0 commit comments

Comments
 (0)