Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 650f84e

Browse files
committed
Review
1 parent 17f5a75 commit 650f84e

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

Sources/TensorFlow/Operators/Math.swift

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -649,16 +649,16 @@ public func rsqrt<T: TensorFlowFloatingPoint>(_ x: Tensor<T>) -> Tensor<T> {
649649
public func cosineSimilarity<Scalar: TensorFlowFloatingPoint>(
650650
_ x: Tensor<Scalar>, _ y: Tensor<Scalar>
651651
) -> Tensor<Scalar> {
652-
return (x * y).sum() /
653-
(sqrt(x.squared().sum()) * sqrt(y.squared().sum()))
652+
(x * y).sum() / (sqrt(x.squared().sum()) * sqrt(y.squared().sum()))
654653
}
655654

656-
/// Returns the cosine distance between `x` and `y`.
657-
@differentiable(wrt: (x, y))
655+
/// Returns the cosine distance between `x` and `y`. Cosine distance is defined as
656+
/// `1 - cosineSimilarity(x, y)`.
657+
@differentiable
658658
public func cosineDistance<Scalar: TensorFlowFloatingPoint>(
659659
_ x: Tensor<Scalar>, _ y: Tensor<Scalar>
660660
) -> Tensor<Scalar> {
661-
return 1 - cosineSimilarity(x, y)
661+
1 - cosineSimilarity(x, y)
662662
}
663663

664664
@inlinable

0 commit comments

Comments
 (0)