Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 3a256de

Browse files
Shashi456 authored and rxwei committed
Adding kullback Leibler Divergence (#226)
1 parent a74064d commit 3a256de

File tree

2 files changed

+21
-0
lines changed

2 files changed

+21
-0
lines changed

Sources/TensorFlow/Loss.swift

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,18 @@ public func poissonLoss<Scalar: TensorFlowFloatingPoint>(
103103
return (predicted - expected * log(predicted)).mean()
104104
}
105105

106+
/// Returns the Kullback-Leibler divergence between predictions and expectations.
///
/// For distributions `expected` (p) and `predicted` (q), this computes
/// `Σ p * log(p / q)`, reduced by summation over all elements.
///
/// - Parameters:
///   - predicted: Predicted outputs from a neural network.
///   - expected: Expected values, i.e. targets, that correspond to the correct output.
@differentiable(wrt: predicted)
public func kullbackLeiblerDivergence<Scalar: TensorFlowFloatingPoint>(
    predicted: Tensor<Scalar>, expected: Tensor<Scalar>
) -> Tensor<Scalar> {
    // KL(expected ‖ predicted): element-wise log-ratio weighted by `expected`.
    let logRatio = log(expected / predicted)
    return (expected * logRatio).sum()
}
117+
106118
/// Computes the softmax cross entropy (categorical cross entropy) between logits and labels.
107119
///
108120
/// - Parameters:

Tests/TensorFlowTests/LossTests.swift

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,14 @@ final class LossTests: XCTestCase {
104104
assertElementsEqual(expected: Tensor(expectedLoss), actual: loss)
105105
}
106106

107+
func testKullbackLeiblerDivergence() {
108+
let predicted = Tensor<Float>([0.2, 0.3, 0.4])
109+
let expected = Tensor<Float>([1.0, 4.0, 3.0])
110+
let loss = kullbackLeiblerDivergence(predicted: predicted, expected: expected)
111+
let expectedLoss: Float = 18.015217
112+
assertElementsEqual(expected: Tensor(expectedLoss), actual: loss)
113+
}
114+
107115
func testSoftmaxCrossEntropyWithProbabilitiesLoss() {
108116
let logits = Tensor<Float>(shape: [2, 4], scalars: [1, 2, 3, 4, 5, 6, 7, 8])
109117
let labels = Tensor<Float>(
@@ -199,6 +207,7 @@ final class LossTests: XCTestCase {
199207
("testMeanSquaredLogarithmicError", testMeanSquaredLogarithmicError),
200208
("testMeanAbsoluteError", testMeanAbsoluteError),
201209
("testHingeLoss", testHingeLoss),
210+
("testKullbackLeiblerDivergence", testKullbackLeiblerDivergence),
202211
("testCategoricalHingeLoss", testCategoricalHingeLoss),
203212
("testSquaredHingeLoss", testSquaredHingeLoss),
204213
("testPoissonLoss", testPoissonLoss),

0 commit comments

Comments
 (0)