Skip to content

Commit 024062d

Browse files
Descartessdan12411
authored andcommitted
Add Cosine Similarity (tensorflow#233)
1 parent a3f389f commit 024062d

File tree

2 files changed

+22
-0
lines changed

2 files changed

+22
-0
lines changed

Sources/TensorFlow/Loss.swift

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,19 @@ public func hingeLoss<Scalar: TensorFlowFloatingPoint>(
6565
return max(Tensor(1) - expected * predicted, Tensor(0)).mean()
6666
}
6767

68+
/// Returns the cosine similarity between predictions and expectations.
69+
///
70+
/// - Parameters:
71+
/// - predicted: Predicted outputs from a neural network.
72+
/// - expected: Expected values, i.e. targets, that correspond to the correct output.
73+
@differentiable(wrt: (predicted, expected))
74+
public func cosineSimilarity<Scalar: TensorFlowFloatingPoint>(
75+
predicted: Tensor<Scalar>, expected: Tensor<Scalar>
76+
) -> Tensor<Scalar> {
77+
return -(expected * predicted).sum() /
78+
(sqrt(expected.squared().sum()) * sqrt(predicted.squared().sum()))
79+
}
80+
6881
/// Returns the squared hinge loss between predictions and expectations.
6982
///
7083
/// - Parameters:

Tests/TensorFlowTests/LossTests.swift

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,14 @@ final class LossTests: XCTestCase {
7878
let expectedLoss: Float = 0.225
7979
assertElementsEqual(expected: Tensor(expectedLoss), actual: loss)
8080
}
81+
82+
func testCosineSimilarityLoss() {
83+
let predicted = Tensor<Float>([1, 2, 3, 4, 5, 6, 7, 8])
84+
let expected = Tensor<Float>([0.5, 1, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0])
85+
let loss = cosineSimilarity(predicted: predicted, expected: expected)
86+
let expectedLoss: Float = -1.0
87+
assertElementsEqual(expected: Tensor(expectedLoss), actual: loss)
88+
}
8189

8290
func testSquaredHingeLoss() {
8391
let predicted = Tensor<Float>([1, 2, 3, 4, 5, 6, 7, 8])
@@ -209,6 +217,7 @@ final class LossTests: XCTestCase {
209217
("testHingeLoss", testHingeLoss),
210218
("testKullbackLeiblerDivergence", testKullbackLeiblerDivergence),
211219
("testCategoricalHingeLoss", testCategoricalHingeLoss),
220+
("testCosineSimilarityLoss", testCosineSimilarityLoss),
212221
("testSquaredHingeLoss", testSquaredHingeLoss),
213222
("testPoissonLoss", testPoissonLoss),
214223
("testSoftmaxCrossEntropyWithProbabilitiesLoss",

0 commit comments

Comments
 (0)