Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 9410ec2

Browse files
committed
fixing test
1 parent 650f84e commit 9410ec2

File tree

2 files changed

+3
-4
lines changed

2 files changed

+3
-4
lines changed

Sources/TensorFlow/Operators/Math.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -645,7 +645,7 @@ public func rsqrt<T: TensorFlowFloatingPoint>(_ x: Tensor<T>) -> Tensor<T> {
645645
}
646646

647647
/// Returns the cosine similarity between `x` and `y`.
648-
@differentiable(wrt: (x, y))
648+
@differentiable
649649
public func cosineSimilarity<Scalar: TensorFlowFloatingPoint>(
650650
_ x: Tensor<Scalar>, _ y: Tensor<Scalar>
651651
) -> Tensor<Scalar> {

Tests/TensorFlowTests/OperatorTests/MathTests.swift

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,8 @@ final class MathOperatorTests: XCTestCase {
6868
func testCosineSimilarity() {
6969
let x = Tensor<Float>([1, 2, 3, 4, 5, 6, 7, 8])
7070
let y = Tensor<Float>([0.5, 1, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0])
71-
let loss = cosineSimilarity(x, y)
72-
let expectedLoss: Float = -1.0
73-
assertElementsEqual(expected: Tensor(expectedLoss), actual: loss)
71+
let z = cosineSimilarity(x, y)
72+
XCTAssertEqual(z, Tensor(1.0), accuracy: 0.0001)
7473
}
7574

7675
// FIXME(https://bugs.swift.org/browse/TF-543): Disable failing test.

0 commit comments

Comments
 (0)