Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 5a5ed0c

Browse files
committed
moving cosine similarity to math ops
1 parent cc648b0 commit 5a5ed0c

File tree

4 files changed

+19
-23
lines changed

4 files changed

+19
-23
lines changed

Sources/TensorFlow/Loss.swift

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -102,19 +102,6 @@ public func hingeLoss<Scalar: TensorFlowFloatingPoint>(
102102
return max(Tensor(1) - expected * predicted, Tensor(0)).mean()
103103
}
104104

105-
/// Returns the cosine similarity between predictions and expectations.
106-
///
107-
/// - Parameters:
108-
/// - predicted: Predicted outputs from a neural network.
109-
/// - expected: Expected values, i.e. targets, that correspond to the correct output.
110-
@differentiable(wrt: (predicted, expected))
111-
public func cosineSimilarity<Scalar: TensorFlowFloatingPoint>(
112-
_ predicted: Tensor<Scalar>, _ expected: Tensor<Scalar>
113-
) -> Tensor<Scalar> {
114-
return -(expected * predicted).sum() /
115-
(sqrt(expected.squared().sum()) * sqrt(predicted.squared().sum()))
116-
}
117-
118105
/// Returns the squared hinge loss between predictions and expectations.
119106
///
120107
/// - Parameters:

Sources/TensorFlow/Operators/Math.swift

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -644,6 +644,15 @@ public func rsqrt<T: TensorFlowFloatingPoint>(_ x: Tensor<T>) -> Tensor<T> {
644644
Raw.rsqrt(x)
645645
}
646646

647+
/// Returns the cosine similarity between x and y.
648+
@differentiable(wrt: (predicted, expected))
649+
public func cosineSimilarity<Scalar: TensorFlowFloatingPoint>(
650+
_ x: Tensor<Scalar>, _ y: Tensor<Scalar>
651+
) -> Tensor<Scalar> {
652+
return -(x * y).sum() /
653+
(sqrt(x.squared().sum()) * sqrt(y.squared().sum()))
654+
}
655+
647656
@inlinable
648657
internal func _vjpRsqrt<T: TensorFlowFloatingPoint>(
649658
_ x: Tensor<T>

Tests/TensorFlowTests/LossTests.swift

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -103,14 +103,6 @@ final class LossTests: XCTestCase {
103103
let expectedLoss: Float = 0.225
104104
assertElementsEqual(expected: Tensor(expectedLoss), actual: loss)
105105
}
106-
107-
func testCosineSimilarityLoss() {
108-
let predicted = Tensor<Float>([1, 2, 3, 4, 5, 6, 7, 8])
109-
let expected = Tensor<Float>([0.5, 1, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0])
110-
let loss = cosineSimilarity(predicted: predicted, expected: expected)
111-
let expectedLoss: Float = -1.0
112-
assertElementsEqual(expected: Tensor(expectedLoss), actual: loss)
113-
}
114106

115107
func testSquaredHingeLoss() {
116108
let predicted = Tensor<Float>([1, 2, 3, 4, 5, 6, 7, 8])
@@ -245,7 +237,6 @@ final class LossTests: XCTestCase {
245237
("testHingeLoss", testHingeLoss),
246238
("testKullbackLeiblerDivergence", testKullbackLeiblerDivergence),
247239
("testCategoricalHingeLoss", testCategoricalHingeLoss),
248-
("testCosineSimilarityLoss", testCosineSimilarityLoss),
249240
("testSquaredHingeLoss", testSquaredHingeLoss),
250241
("testPoissonLoss", testPoissonLoss),
251242
("testSoftmaxCrossEntropyWithProbabilitiesLoss",

Tests/TensorFlowTests/OperatorTests/MathTests.swift

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,14 @@ final class MathOperatorTests: XCTestCase {
6565
assertEqual(y, log(1 + x), accuracy: 0.0001)
6666
}
6767

68+
func testCosineSimilarity() {
69+
let x = Tensor<Float>([1, 2, 3, 4, 5, 6, 7, 8])
70+
let y = Tensor<Float>([0.5, 1, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0])
71+
let loss = cosineSimilarity(x, y)
72+
let expectedLoss: Float = -1.0
73+
assertElementsEqual(expected: Tensor(expectedLoss), actual: loss)
74+
}
75+
6876
// FIXME(https://bugs.swift.org/browse/TF-543): Disable failing test.
6977
/*
7078
func testExpm1() {
@@ -118,7 +126,7 @@ final class MathOperatorTests: XCTestCase {
118126
x.variance(squeezingAxes: 0),
119127
Tensor(shape: [5], scalars: [0, 0, 0, 0, 0]))
120128
XCTAssertEqual(
121-
x.variance(alongAxes: 0),
129+
x.variance(alongAxes: 0),
122130
Tensor(shape: [5], scalars: [0, 0, 0, 0, 0]))
123131
XCTAssertEqual(
124132
x.variance(squeezingAxes: 1),
@@ -280,6 +288,7 @@ final class MathOperatorTests: XCTestCase {
280288
// ("testExpm1", testExpm1),
281289
("testSign", testSign),
282290
("testReduction", testReduction),
291+
("testCosineSimilarity", testCosineSimilarity)
283292
("testArgmax", testArgmax),
284293
("testCeilAndFloor", testCeilAndFloor),
285294
("testSimpleMath", testSimpleMath),

0 commit comments

Comments
 (0)