@@ -1890,7 +1890,7 @@ public extension Tensor where Scalar: Numeric {
     /// - Parameter axes: The dimensions to reduce.
     /// - Precondition: Each value in `axes` must be in the range `-rank...rank`.
     @inlinable
-    @differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
+    @differentiable(wrt: self, vjp: _vjpMean(squeezingAxes:) where Scalar: TensorFlowFloatingPoint)
     func mean(squeezingAxes axes: [Int]) -> Tensor {
         // TODO(TF-433): Remove workaround for differentiating `map`.
         let axes = {axes.map(Int32.init)}()
@@ -1927,7 +1927,7 @@ public extension Tensor where Scalar: Numeric {
     /// - Parameter axes: The dimensions to reduce.
     /// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
     @inlinable
-    @differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
+    @differentiable(wrt: self, vjp: _vjpMean(alongAxes:) where Scalar: TensorFlowFloatingPoint)
     func mean(alongAxes axes: [Int]) -> Tensor {
         // TODO(TF-433): Remove workaround for differentiating `map`.
         let axes = {axes.map(Int32.init)}()
@@ -2201,6 +2201,25 @@ internal extension Tensor where Scalar: TensorFlowFloatingPoint {
         })
     }
 
+    @inlinable
+    func _vjpMean(alongAxes axes: [Int]) -> (Tensor, (Tensor) -> Tensor) {
+        let value = mean(alongAxes: axes)
+        let myShape = shape
+        let count = axes.map { myShape[$0] }.reduce(1, *)
+        return (value, { [shape = shapeTensor] in $0.broadcasted(toShape: shape) / Tensor(Scalar(count)) })
+    }
+
+    @inlinable
+    func _vjpMean(squeezingAxes axes: [Int]) -> (Tensor, (Tensor) -> Tensor) {
+        let value = mean(squeezingAxes: axes)
+        let myShape = shape
+        let count = axes.map { myShape[$0] }.reduce(1, *)
+        return (value, { [shape = shapeTensor] v in
+            let unsqueezed = v.expandingShape(at: axes)
+            return unsqueezed.broadcasted(toShape: shape) / Tensor(Scalar(count))
+        })
+    }
+
     @inlinable
     func _vjpCumulativeSum(
         alongAxis axis: Tensor<Int32>,
@@ -2606,7 +2625,7 @@ internal func _vjpMatmul<Scalar: TensorFlowFloatingPoint>(
     transposed transposeRhs: Bool = false
 ) -> (Tensor<Scalar>, (Tensor<Scalar>) -> (Tensor<Scalar>, Tensor<Scalar>)) {
     let value = matmul(lhs, transposed: transposeLhs, rhs, transposed: transposeRhs)
-    return (value, { [lhsShape = lhs.shapeTensor, rhsShape = rhs.shapeTensor] v in
+    return (value, { [lhsShape = lhs.shape, rhsShape = rhs.shape] v in
         let (lhsGrad, rhsGrad): (Tensor<Scalar>, Tensor<Scalar>)
         switch (transposeLhs, transposeRhs) {
         case (false, false):
@@ -2622,13 +2641,15 @@ internal func _vjpMatmul<Scalar: TensorFlowFloatingPoint>(
             lhsGrad = matmul(v, transposed: true, rhs, transposed: true)
             rhsGrad = matmul(lhs, transposed: true, v, transposed: true)
         }
-        let lhsRank = lhsShape.shape[0] - 2
-        let rhsRank = rhsShape.shape[0] - 2
+        let lhsRank = lhsShape.rank - 2
+        let rhsRank = rhsShape.rank - 2
         let (lhsAxes, rhsAxes) = _Raw.broadcastGradientArgs(
-            s0: lhsShape[..<lhsRank],
-            s1: rhsShape[..<rhsRank])
-        return (lhsGrad.sum(squeezingAxes: lhsAxes).reshaped(toShape: lhsShape),
-                rhsGrad.sum(squeezingAxes: rhsAxes).reshaped(toShape: rhsShape))
+            s0: Tensor<Int32>(lhsShape.dimensions[..<lhsRank].map { Int32($0) }),
+            s1: Tensor<Int32>(rhsShape.dimensions[..<rhsRank].map { Int32($0) }))
+        let lhsShapeTensor = Tensor<Int32>(lhsShape.dimensions.map { Int32($0) })
+        let rhsShapeTensor = Tensor<Int32>(rhsShape.dimensions.map { Int32($0) })
+        return (lhsGrad.sum(squeezingAxes: lhsAxes).reshaped(toShape: lhsShapeTensor),
+                rhsGrad.sum(squeezingAxes: rhsAxes).reshaped(toShape: rhsShapeTensor))
     })
 }
 
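For intuition about what the new `_vjpMean` pullbacks compute, here is a small sketch that is not part of the patch, assuming the swift-apis `TensorFlow` module and its `gradient(at:in:)` helper are available: the pullback broadcasts the incoming gradient back to the input shape and divides by the number of reduced elements.

```swift
import TensorFlow

// Illustrative sketch only; a [2, 3] input reduced along axis 1.
let x: Tensor<Float> = [[1, 2, 3],
                        [4, 5, 6]]

// mean(alongAxes: [1]) divides by the 3 reduced elements, so the gradient of
// its sum is 1/3 at every input position.
let grad = gradient(at: x) { x in
    x.mean(alongAxes: [1]).sum()
}
print(grad)  // Expected: [[0.3333, 0.3333, 0.3333], [0.3333, 0.3333, 0.3333]]
```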