@@ -1890,7 +1890,7 @@ public extension Tensor where Scalar: Numeric {
     /// - Parameter axes: The dimensions to reduce.
     /// - Precondition: Each value in `axes` must be in the range `-rank...rank`.
     @inlinable
-    @differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
+    @differentiable(wrt: self, vjp: _vjpMean(squeezingAxes:) where Scalar: TensorFlowFloatingPoint)
     func mean(squeezingAxes axes: [Int]) -> Tensor {
         // TODO(TF-433): Remove workaround for differentiating `map`.
         let axes = {axes.map(Int32.init)}()
@@ -1927,7 +1927,7 @@ public extension Tensor where Scalar: Numeric {
     /// - Parameter axes: The dimensions to reduce.
     /// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
     @inlinable
-    @differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
+    @differentiable(wrt: self, vjp: _vjpMean(alongAxes:) where Scalar: TensorFlowFloatingPoint)
     func mean(alongAxes axes: [Int]) -> Tensor {
         // TODO(TF-433): Remove workaround for differentiating `map`.
         let axes = {axes.map(Int32.init)}()
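Note: these first two hunks only change how the derivatives are wired up; instead of letting the compiler synthesize a pullback, each `mean` overload now names its manual VJP through the old `@differentiable(vjp:)` attribute (removed in later autodiff redesigns in favor of `@derivative(of:)`). As a rough illustration of that registration pattern, and not part of this commit, here is a toy sketch in the same pre-redesign syntax; `squared` and `_vjpSquared` are hypothetical names, and it would only compile on a Swift for TensorFlow toolchain of this era.

import TensorFlow

// Hypothetical example (not from this commit): `vjp:` points the compiler at a
// function returning (value, pullback) with a matching signature.
@differentiable(vjp: _vjpSquared)
func squared(_ x: Tensor<Float>) -> Tensor<Float> {
    return x * x
}

// Pullback of x * x: v maps to 2 * x * v.
func _vjpSquared(_ x: Tensor<Float>) -> (Tensor<Float>, (Tensor<Float>) -> Tensor<Float>) {
    return (squared(x), { v in 2 * x * v })
}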
@@ -2201,6 +2201,31 @@ internal extension Tensor where Scalar: TensorFlowFloatingPoint {
         })
     }
 
+    // Specialization to avoid _Raw.gather on shapes when axes is known to be
+    // [Int].
+    @inlinable
+    func _vjpMean(alongAxes axes: [Int]) -> (Tensor, (Tensor) -> Tensor) {
+        let value = mean(alongAxes: axes)
+        // Cache shape because it is a computed property.
+        let cachedShape = shape
+        let count = axes.map { cachedShape[$0] }.reduce(1, *)
+        return (value, { [shape = shapeTensor] in $0.broadcasted(toShape: shape) / Tensor(Scalar(count)) })
+    }
+
+    // Specialization to avoid _Raw.gather on shapes when axes is known to be
+    // [Int].
+    @inlinable
+    func _vjpMean(squeezingAxes axes: [Int]) -> (Tensor, (Tensor) -> Tensor) {
+        let value = mean(squeezingAxes: axes)
+        // Cache shape because it is a computed property.
+        let cachedShape = shape
+        let count = axes.map { cachedShape[$0] }.reduce(1, *)
+        return (value, { [shape = shapeTensor] v in
+            let unsqueezed = v.expandingShape(at: axes)
+            return unsqueezed.broadcasted(toShape: shape) / Tensor(Scalar(count))
+        })
+    }
+
     @inlinable
     func _vjpCumulativeSum(
         alongAxis axis: Tensor<Int32>,
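The `[Int]`-axes specializations added above compute the pullback of a mean reduction directly: broadcast the incoming gradient back to the input shape and divide it by the number of reduced elements (the product of the reduced dimension sizes), avoiding the `_Raw.gather` call on the shape tensor used by the generic path. Not part of the commit, the snippet below is a small numeric sanity check of that rule, assuming a Swift for TensorFlow toolchain where `gradient(at:in:)` is available.

import TensorFlow

// Averaging over axis 1 of a [2, 3] tensor reduces 3 elements, so each input
// element should receive 1/3 of the upstream gradient when the output is summed.
let x = Tensor<Float>(shape: [2, 3], scalars: [1, 2, 3, 4, 5, 6])
let grad = gradient(at: x) { $0.mean(alongAxes: [1]).sum() }
print(grad)  // expected: a [2, 3] tensor filled with ~0.3333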
@@ -2606,7 +2631,7 @@ internal func _vjpMatmul<Scalar: TensorFlowFloatingPoint>(
     transposed transposeRhs: Bool = false
 ) -> (Tensor<Scalar>, (Tensor<Scalar>) -> (Tensor<Scalar>, Tensor<Scalar>)) {
     let value = matmul(lhs, transposed: transposeLhs, rhs, transposed: transposeRhs)
-    return (value, { [lhsShape = lhs.shapeTensor, rhsShape = rhs.shapeTensor] v in
+    return (value, { [lhsShape = lhs.shape, rhsShape = rhs.shape] v in
         let (lhsGrad, rhsGrad): (Tensor<Scalar>, Tensor<Scalar>)
         switch (transposeLhs, transposeRhs) {
         case (false, false):
@@ -2622,13 +2647,15 @@ internal func _vjpMatmul<Scalar: TensorFlowFloatingPoint>(
             lhsGrad = matmul(v, transposed: true, rhs, transposed: true)
             rhsGrad = matmul(lhs, transposed: true, v, transposed: true)
         }
-        let lhsRank = lhsShape.shape[0] - 2
-        let rhsRank = rhsShape.shape[0] - 2
+        let lhsRank = lhsShape.rank - 2
+        let rhsRank = rhsShape.rank - 2
         let (lhsAxes, rhsAxes) = _Raw.broadcastGradientArgs(
-            s0: lhsShape[..<lhsRank],
-            s1: rhsShape[..<rhsRank])
-        return (lhsGrad.sum(squeezingAxes: lhsAxes).reshaped(toShape: lhsShape),
-                rhsGrad.sum(squeezingAxes: rhsAxes).reshaped(toShape: rhsShape))
+            s0: Tensor<Int32>(lhsShape.dimensions[..<lhsRank].map { Int32($0) }),
+            s1: Tensor<Int32>(rhsShape.dimensions[..<rhsRank].map { Int32($0) }))
+        let lhsShapeTensor = Tensor<Int32>(lhsShape.dimensions.map { Int32($0) })
+        let rhsShapeTensor = Tensor<Int32>(rhsShape.dimensions.map { Int32($0) })
+        return (lhsGrad.sum(squeezingAxes: lhsAxes).reshaped(toShape: lhsShapeTensor),
+                rhsGrad.sum(squeezingAxes: rhsAxes).reshaped(toShape: rhsShapeTensor))
     })
 }
 
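In the matmul pullback, the closure now captures the lightweight `TensorShape` values rather than `shapeTensor` (a computed property that runs a shape op), derives the batch-dimension count from `rank - 2` on the captured shapes, and only materializes `Tensor<Int32>` shapes where `_Raw.broadcastGradientArgs` and `reshaped(toShape:)` require them. Not part of the commit, here is a quick sanity check that this gradient path produces the expected values, again assuming a Swift for TensorFlow toolchain with `gradient(at:in:)`.

import TensorFlow

// For f(A) = sum(matmul(A, B)), dL/dA = ones * B-transposed, i.e. every row of
// the gradient equals the row sums of B (here [1, 1, 2]).
let a = Tensor<Float>(shape: [2, 3], scalars: [1, 2, 3, 4, 5, 6])
let b = Tensor<Float>(shape: [3, 2], scalars: [1, 0, 0, 1, 1, 1])
let gradA = gradient(at: a) { matmul($0, b).sum() }
print(gradA)  // expected: [[1.0, 1.0, 2.0], [1.0, 1.0, 2.0]]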