@@ -36,7 +36,7 @@ public extension PointwiseMultiplicative {
 }
 
 //===------------------------------------------------------------------------------------------===//
-// Generic elementary functions
+// Generic Elementary Functions
 //===------------------------------------------------------------------------------------------===//
 
 extension Tensor: ElementaryFunctions where Scalar: TensorFlowFloatingPoint {
@@ -494,6 +494,58 @@ public extension Tensor where Scalar == Bool {
   }
 }
 
+public extension Tensor where Scalar: TensorFlowNumeric {
+  /// Returns `max(min(self, max), min)`.
+  @inlinable
+  @differentiable(vjp: _vjpClipped where Scalar: TensorFlowFloatingPoint)
+  func clipped(min: Tensor, max: Tensor) -> Tensor {
+    Raw.clipByValue(t: self, clipValueMin: min, clipValueMax: max)
+  }
+
+  /// Returns `max(min(self, max), min)`.
+  @inlinable
+  @differentiable(wrt: (self, min) where Scalar: TensorFlowFloatingPoint)
+  func clipped(min: Tensor, max: Scalar) -> Tensor {
+    clipped(min: min, max: Tensor(max))
+  }
+
+  /// Returns `max(min(self, max), min)`.
+  @inlinable
+  @differentiable(wrt: (self, max) where Scalar: TensorFlowFloatingPoint)
+  func clipped(min: Scalar, max: Tensor) -> Tensor {
+    clipped(min: Tensor(min), max: max)
+  }
+
+  /// Returns `max(min(self, max), min)`.
+  @inlinable
+  @differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
+  func clipped(min: Scalar, max: Scalar) -> Tensor {
+    clipped(min: Tensor(min), max: Tensor(max))
+  }
+}
+
+internal extension Tensor where Scalar: TensorFlowFloatingPoint {
+  @inlinable
+  func _vjpClipped(min: Tensor, max: Tensor) -> (Tensor, (Tensor) -> (Tensor, Tensor, Tensor)) {
+    (clipped(min: min, max: max), { v in
+      let selfShape = self.shapeTensor
+      let minShape = min.shapeTensor
+      let maxShape = max.shapeTensor
+      let zeros = Tensor(zerosLike: v)
+      let minMask = self .< min
+      let maxMask = self .> max
+      let selfGradient = v.replacing(with: zeros, where: minMask.elementsLogicalOr(maxMask))
+      let minGradient = zeros.replacing(with: v, where: minMask)
+      let maxGradient = zeros.replacing(with: v, where: maxMask)
+      let (selfAxes, minAxes) = Raw.broadcastGradientArgs(s0: selfShape, s1: minShape)
+      let (_, maxAxes) = Raw.broadcastGradientArgs(s0: selfShape, s1: maxShape)
+      return (selfGradient.sum(squeezingAxes: selfAxes).reshaped(toShape: selfShape),
+              minGradient.sum(squeezingAxes: minAxes).reshaped(toShape: minShape),
+              maxGradient.sum(squeezingAxes: maxAxes).reshaped(toShape: maxShape))
+    })
+  }
+}
+
 //===------------------------------------------------------------------------------------------===//
 // Element-wise Unary Math Functions
 //===------------------------------------------------------------------------------------------===//
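For reference, a minimal usage sketch of the new API (the tensor values and bounds below are illustrative, not part of this change): `clipped` clamps every element into `[min, max]`, and the `Scalar` overloads simply wrap the bound in a `Tensor` and forward to the `Tensor`/`Tensor` variant.

import TensorFlow

let x = Tensor<Float>([-2, -0.5, 0.5, 2])

// Tensor bounds and scalar bounds give the same result here.
let a = x.clipped(min: Tensor<Float>(-1), max: Tensor<Float>(1))  // [-1.0, -0.5, 0.5, 1.0]
let b = x.clipped(min: -1, max: 1)                                // [-1.0, -0.5, 0.5, 1.0]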
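And a sketch of the gradient behavior implemented by `_vjpClipped`, assuming the `gradient(at:in:)` differential operator from the same library: elements clamped at a bound contribute zero gradient with respect to `self`, and their incoming gradient is routed to the corresponding bound instead, with a broadcast-aware reduction via `broadcastGradientArgs`.

import TensorFlow

let y = Tensor<Float>([-2, 0, 2])
let lo = Tensor<Float>(-1)
let hi = Tensor<Float>(1)

// Gradient of the scalar sum of clipped(y) with respect to y:
// 1 where the element stays inside the bounds, 0 where it is clamped.
let dy = gradient(at: y) { y in y.clipped(min: lo, max: hi).sum() }
// dy == [0.0, 1.0, 0.0]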