@@ -1228,22 +1228,23 @@ public func max<T>(_ lhs: Tensor<T>, _ rhs: Tensor<T>) -> Tensor<T> where T: Num
1228
1228
1229
1229
@inlinable
1230
1230
internal func _vjpMax< T: TensorFlowFloatingPoint > (
1231
- _ x: Tensor < T > , _ y: Tensor < T >
1231
+ _ x: Tensor < T > ,
1232
+ _ y: Tensor < T >
1232
1233
) -> ( Tensor < T > , ( Tensor < T > ) -> ( Tensor < T > , Tensor < T > ) ) {
1233
1234
let value = max ( x, y)
1234
1235
return ( value, { v in _vjpMinMaxHelper ( x, y, originalValue: value, seed: v) } )
1235
1236
}
1236
1237
1237
1238
/// Returns the element-wise maximum of the scalar and the tensor, broadcasting the scalar.
1238
1239
@inlinable
1239
- // @differentiable(where T: TensorFlowFloatingPoint)
1240
+ @differentiable ( wrt : rhs where T: TensorFlowFloatingPoint)
1240
1241
public func max< T> ( _ lhs: T , _ rhs: Tensor < T > ) -> Tensor < T > where T: Numeric & Comparable {
1241
1242
max ( Tensor ( lhs) , rhs)
1242
1243
}
1243
1244
1244
1245
/// Returns the element-wise maximum of the scalar and the tensor, broadcasting the scalar.
1245
1246
@inlinable
1246
- // @differentiable(where T: TensorFlowFloatingPoint)
1247
+ @differentiable ( wrt : lhs where T: TensorFlowFloatingPoint)
1247
1248
public func max< T> ( _ lhs: Tensor < T > , _ rhs: T ) -> Tensor < T > where T: Numeric & Comparable {
1248
1249
max ( lhs, Tensor ( rhs) )
1249
1250
}
@@ -1258,22 +1259,23 @@ public func min<T>(_ lhs: Tensor<T>, _ rhs: Tensor<T>) -> Tensor<T> where T: Num
1258
1259
1259
1260
@inlinable
1260
1261
internal func _vjpMin< T: TensorFlowFloatingPoint > (
1261
- _ x: Tensor < T > , _ y: Tensor < T >
1262
+ _ x: Tensor < T > ,
1263
+ _ y: Tensor < T >
1262
1264
) -> ( Tensor < T > , ( Tensor < T > ) -> ( Tensor < T > , Tensor < T > ) ) {
1263
1265
let value = min ( x, y)
1264
1266
return ( value, { v in _vjpMinMaxHelper ( x, y, originalValue: value, seed: v) } )
1265
1267
}
1266
1268
1267
1269
/// Returns the element-wise minimum of the scalar and the tensor, broadcasting the scalar.
1268
1270
@inlinable
1269
- // @differentiable(where T: TensorFlowFloatingPoint)
1271
+ @differentiable ( wrt : rhs where T: TensorFlowFloatingPoint)
1270
1272
public func min< T> ( _ lhs: T , _ rhs: Tensor < T > ) -> Tensor < T > where T: Numeric & Comparable {
1271
1273
min ( Tensor ( lhs) , rhs)
1272
1274
}
1273
1275
1274
1276
/// Returns the element-wise minimum of the scalar and the tensor, broadcasting the scalar.
1275
1277
@inlinable
1276
- // @differentiable(where T: TensorFlowFloatingPoint)
1278
+ @differentiable ( wrt : lhs where T: TensorFlowFloatingPoint)
1277
1279
public func min< T> ( _ lhs: Tensor < T > , _ rhs: T ) -> Tensor < T > where T: Numeric & Comparable {
1278
1280
min ( lhs, Tensor ( rhs) )
1279
1281
}
@@ -1297,7 +1299,8 @@ internal func _vjpMinMaxHelper<T: TensorFlowFloatingPoint>(
1297
1299
/// Returns the cosine similarity between `x` and `y`.
1298
1300
@differentiable
1299
1301
public func cosineSimilarity< Scalar: TensorFlowFloatingPoint > (
1300
- _ x: Tensor < Scalar > , _ y: Tensor < Scalar >
1302
+ _ x: Tensor < Scalar > ,
1303
+ _ y: Tensor < Scalar >
1301
1304
) -> Tensor < Scalar > {
1302
1305
( x * y) . sum ( ) / ( sqrt ( x. squared ( ) . sum ( ) ) * sqrt( y. squared ( ) . sum ( ) ) )
1303
1306
}
@@ -1306,7 +1309,8 @@ public func cosineSimilarity<Scalar: TensorFlowFloatingPoint>(
1306
1309
/// `1 - cosineSimilarity(x, y)`.
1307
1310
@differentiable
1308
1311
public func cosineDistance< Scalar: TensorFlowFloatingPoint > (
1309
- _ x: Tensor < Scalar > , _ y: Tensor < Scalar >
1312
+ _ x: Tensor < Scalar > ,
1313
+ _ y: Tensor < Scalar >
1310
1314
) -> Tensor < Scalar > {
1311
1315
1 - cosineSimilarity( x, y)
1312
1316
}
0 commit comments