Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit c108a10

Browse files
Shashi456rxwei
authored andcommitted
Refactor cosine similarity and add cosine distance (#240)
1 parent 5c5f557 commit c108a10

File tree

4 files changed

+26
-22
lines changed

4 files changed

+26
-22
lines changed

Sources/TensorFlow/Loss.swift

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -102,19 +102,6 @@ public func hingeLoss<Scalar: TensorFlowFloatingPoint>(
102102
return max(Tensor(1) - expected * predicted, Tensor(0)).mean()
103103
}
104104

105-
/// Returns the cosine similarity between predictions and expectations.
106-
///
107-
/// - Parameters:
108-
/// - predicted: Predicted outputs from a neural network.
109-
/// - expected: Expected values, i.e. targets, that correspond to the correct output.
110-
@differentiable(wrt: (predicted, expected))
111-
public func cosineSimilarity<Scalar: TensorFlowFloatingPoint>(
112-
predicted: Tensor<Scalar>, expected: Tensor<Scalar>
113-
) -> Tensor<Scalar> {
114-
return -(expected * predicted).sum() /
115-
(sqrt(expected.squared().sum()) * sqrt(predicted.squared().sum()))
116-
}
117-
118105
/// Returns the squared hinge loss between predictions and expectations.
119106
///
120107
/// - Parameters:

Sources/TensorFlow/Operators/Math.swift

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -774,6 +774,23 @@ public func rsqrt<T: TensorFlowFloatingPoint>(_ x: Tensor<T>) -> Tensor<T> {
774774
Raw.rsqrt(x)
775775
}
776776

777+
/// Returns the cosine similarity between `x` and `y`.
778+
@differentiable
779+
public func cosineSimilarity<Scalar: TensorFlowFloatingPoint>(
780+
_ x: Tensor<Scalar>, _ y: Tensor<Scalar>
781+
) -> Tensor<Scalar> {
782+
(x * y).sum() / (sqrt(x.squared().sum()) * sqrt(y.squared().sum()))
783+
}
784+
785+
/// Returns the cosine distance between `x` and `y`. Cosine distance is defined as
786+
/// `1 - cosineSimilarity(x, y)`.
787+
@differentiable
788+
public func cosineDistance<Scalar: TensorFlowFloatingPoint>(
789+
_ x: Tensor<Scalar>, _ y: Tensor<Scalar>
790+
) -> Tensor<Scalar> {
791+
1 - cosineSimilarity(x, y)
792+
}
793+
777794
@inlinable
778795
internal func _vjpRsqrt<T: TensorFlowFloatingPoint>(
779796
_ x: Tensor<T>

Tests/TensorFlowTests/LossTests.swift

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -103,14 +103,6 @@ final class LossTests: XCTestCase {
103103
let expectedLoss: Float = 0.225
104104
assertElementsEqual(expected: Tensor(expectedLoss), actual: loss)
105105
}
106-
107-
func testCosineSimilarityLoss() {
108-
let predicted = Tensor<Float>([1, 2, 3, 4, 5, 6, 7, 8])
109-
let expected = Tensor<Float>([0.5, 1, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0])
110-
let loss = cosineSimilarity(predicted: predicted, expected: expected)
111-
let expectedLoss: Float = -1.0
112-
assertElementsEqual(expected: Tensor(expectedLoss), actual: loss)
113-
}
114106

115107
func testSquaredHingeLoss() {
116108
let predicted = Tensor<Float>([1, 2, 3, 4, 5, 6, 7, 8])
@@ -245,7 +237,6 @@ final class LossTests: XCTestCase {
245237
("testHingeLoss", testHingeLoss),
246238
("testKullbackLeiblerDivergence", testKullbackLeiblerDivergence),
247239
("testCategoricalHingeLoss", testCategoricalHingeLoss),
248-
("testCosineSimilarityLoss", testCosineSimilarityLoss),
249240
("testSquaredHingeLoss", testSquaredHingeLoss),
250241
("testPoissonLoss", testPoissonLoss),
251242
("testSoftmaxCrossEntropyWithProbabilitiesLoss",

Tests/TensorFlowTests/OperatorTests/MathTests.swift

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,14 @@ final class MathOperatorTests: XCTestCase {
6565
assertEqual(y, log(1 + x), accuracy: 0.0001)
6666
}
6767

68+
func testCosineSimilarity() {
69+
let x = Tensor<Float>([1, 2, 3, 4, 5, 6, 7, 8])
70+
let y = Tensor<Float>([0.5, 1, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0])
71+
let z = cosineSimilarity(x, y)
72+
let output: Float = 1.0
73+
XCTAssertEqual(z, Tensor(output))
74+
}
75+
6876
// FIXME(https://bugs.swift.org/browse/TF-543): Disable failing test.
6977
/*
7078
func testExpm1() {
@@ -294,6 +302,7 @@ final class MathOperatorTests: XCTestCase {
294302
// ("testExpm1", testExpm1),
295303
("testSign", testSign),
296304
("testReduction", testReduction),
305+
("testCosineSimilarity", testCosineSimilarity),
297306
("testArgmax", testArgmax),
298307
("testSoftplus", testSoftplus),
299308
("testSoftsign", testSoftsign),

0 commit comments

Comments
 (0)