Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 17f5a75

Browse files
committed
Anthony's comments
1 parent e76de20 commit 17f5a75

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

Sources/TensorFlow/Operators/Math.swift

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -644,15 +644,23 @@ public func rsqrt<T: TensorFlowFloatingPoint>(_ x: Tensor<T>) -> Tensor<T> {
644644
Raw.rsqrt(x)
645645
}
646646

647-
/// Returns the cosine similarity between x and y.
647+
/// Returns the cosine similarity between `x` and `y`.
648648
@differentiable(wrt: (x, y))
649649
public func cosineSimilarity<Scalar: TensorFlowFloatingPoint>(
650650
_ x: Tensor<Scalar>, _ y: Tensor<Scalar>
651651
) -> Tensor<Scalar> {
652-
return -(x * y).sum() /
652+
return (x * y).sum() /
653653
(sqrt(x.squared().sum()) * sqrt(y.squared().sum()))
654654
}
655655

656+
/// Returns the cosine distance between `x` and `y`.
657+
@differentiable(wrt: (x, y))
658+
public func cosineDistance<Scalar: TensorFlowFloatingPoint>(
659+
_ x: Tensor<Scalar>, _ y: Tensor<Scalar>
660+
) -> Tensor<Scalar> {
661+
return 1 - cosineSimilarity(x, y)
662+
}
663+
656664
@inlinable
657665
internal func _vjpRsqrt<T: TensorFlowFloatingPoint>(
658666
_ x: Tensor<T>

0 commit comments

Comments
 (0)