Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit a3b3aa0

Browse files
eaplataniosrxwei
authored andcommitted
Added '@differentiable' to some overloads of 'min(_:_:)' and 'max(_:_:)'. (#360)
1 parent 2be7fd9 commit a3b3aa0

File tree

1 file changed

+12
-8
lines changed

1 file changed

+12
-8
lines changed

Sources/TensorFlow/Operators/Math.swift

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1228,22 +1228,23 @@ public func max<T>(_ lhs: Tensor<T>, _ rhs: Tensor<T>) -> Tensor<T> where T: Num
12281228

12291229
@inlinable
12301230
internal func _vjpMax<T: TensorFlowFloatingPoint>(
1231-
_ x: Tensor<T>, _ y: Tensor<T>
1231+
_ x: Tensor<T>,
1232+
_ y: Tensor<T>
12321233
) -> (Tensor<T>, (Tensor<T>) -> (Tensor<T>, Tensor<T>)) {
12331234
let value = max(x, y)
12341235
return (value, { v in _vjpMinMaxHelper(x, y, originalValue: value, seed: v) })
12351236
}
12361237

12371238
/// Returns the element-wise maximum of the scalar and the tensor, broadcasting the scalar.
12381239
@inlinable
1239-
// @differentiable(where T: TensorFlowFloatingPoint)
1240+
@differentiable(wrt: rhs where T: TensorFlowFloatingPoint)
12401241
public func max<T>(_ lhs: T, _ rhs: Tensor<T>) -> Tensor<T> where T: Numeric & Comparable {
12411242
max(Tensor(lhs), rhs)
12421243
}
12431244

12441245
/// Returns the element-wise maximum of the scalar and the tensor, broadcasting the scalar.
12451246
@inlinable
1246-
// @differentiable(where T: TensorFlowFloatingPoint)
1247+
@differentiable(wrt: lhs where T: TensorFlowFloatingPoint)
12471248
public func max<T>(_ lhs: Tensor<T>, _ rhs: T) -> Tensor<T> where T: Numeric & Comparable {
12481249
max(lhs, Tensor(rhs))
12491250
}
@@ -1258,22 +1259,23 @@ public func min<T>(_ lhs: Tensor<T>, _ rhs: Tensor<T>) -> Tensor<T> where T: Num
12581259

12591260
@inlinable
12601261
internal func _vjpMin<T: TensorFlowFloatingPoint>(
1261-
_ x: Tensor<T>, _ y: Tensor<T>
1262+
_ x: Tensor<T>,
1263+
_ y: Tensor<T>
12621264
) -> (Tensor<T>, (Tensor<T>) -> (Tensor<T>, Tensor<T>)) {
12631265
let value = min(x, y)
12641266
return (value, { v in _vjpMinMaxHelper(x, y, originalValue: value, seed: v) })
12651267
}
12661268

12671269
/// Returns the element-wise minimum of the scalar and the tensor, broadcasting the scalar.
12681270
@inlinable
1269-
// @differentiable(where T: TensorFlowFloatingPoint)
1271+
@differentiable(wrt: rhs where T: TensorFlowFloatingPoint)
12701272
public func min<T>(_ lhs: T, _ rhs: Tensor<T>) -> Tensor<T> where T: Numeric & Comparable {
12711273
min(Tensor(lhs), rhs)
12721274
}
12731275

12741276
/// Returns the element-wise minimum of the scalar and the tensor, broadcasting the scalar.
12751277
@inlinable
1276-
// @differentiable(where T: TensorFlowFloatingPoint)
1278+
@differentiable(wrt: lhs where T: TensorFlowFloatingPoint)
12771279
public func min<T>(_ lhs: Tensor<T>, _ rhs: T) -> Tensor<T> where T: Numeric & Comparable {
12781280
min(lhs, Tensor(rhs))
12791281
}
@@ -1297,7 +1299,8 @@ internal func _vjpMinMaxHelper<T: TensorFlowFloatingPoint>(
12971299
/// Returns the cosine similarity between `x` and `y`.
12981300
@differentiable
12991301
public func cosineSimilarity<Scalar: TensorFlowFloatingPoint>(
1300-
_ x: Tensor<Scalar>, _ y: Tensor<Scalar>
1302+
_ x: Tensor<Scalar>,
1303+
_ y: Tensor<Scalar>
13011304
) -> Tensor<Scalar> {
13021305
(x * y).sum() / (sqrt(x.squared().sum()) * sqrt(y.squared().sum()))
13031306
}
@@ -1306,7 +1309,8 @@ public func cosineSimilarity<Scalar: TensorFlowFloatingPoint>(
13061309
/// `1 - cosineSimilarity(x, y)`.
13071310
@differentiable
13081311
public func cosineDistance<Scalar: TensorFlowFloatingPoint>(
1309-
_ x: Tensor<Scalar>, _ y: Tensor<Scalar>
1312+
_ x: Tensor<Scalar>,
1313+
_ y: Tensor<Scalar>
13101314
) -> Tensor<Scalar> {
13111315
1 - cosineSimilarity(x, y)
13121316
}

0 commit comments

Comments
 (0)