@@ -569,25 +569,15 @@ extension Tensor where Scalar : TensorFlowFloatingPoint {
569
569
570
570
extension Tensor where Scalar : TensorFlowFloatingPoint {
571
571
@inlinable
572
- func _vjpSum( ) -> ( Tensor , ( Tensor ) -> Tensor ) {
573
- return ( sum ( ) , { [ shape = shapeTensor] in $0. broadcast ( toShape: shape) } )
574
- }
575
-
576
- @inlinable
577
- func _vjpMean( ) -> ( Tensor , ( Tensor ) -> Tensor ) {
578
- return ( mean ( ) , { [ shape = shapeTensor, count = scalarCountTensor] in
579
- ( $0 / Tensor( count) ) . broadcast ( toShape: shape)
580
- } )
581
- }
582
-
583
- @inlinable
584
- func _vjpSum( alongAxes axes: [ Int32 ] ) -> ( Tensor , ( Tensor ) -> Tensor ) {
572
+ func _vjpSum( alongAxes axes: Tensor < Int32 > ) -> ( Tensor , ( Tensor ) -> Tensor ) {
585
573
let value = sum ( alongAxes: axes)
586
574
return ( value, { [ shape = shapeTensor] in $0. broadcast ( toShape: shape) } )
587
575
}
588
576
589
577
@inlinable
590
- func _vjpSum( squeezingAxes axes: [ Int32 ] ) -> ( Tensor , ( Tensor ) -> Tensor ) {
578
+ func _vjpSum(
579
+ squeezingAxes axes: Tensor < Int32 >
580
+ ) -> ( Tensor , ( Tensor ) -> Tensor ) {
591
581
let value = sum ( squeezingAxes: axes)
592
582
return ( value, { [ shape = shapeTensor] in $0. broadcast ( toShape: shape) } )
593
583
}
@@ -602,20 +592,13 @@ extension Tensor where Scalar : TensorFlowFloatingPoint {
602
592
}
603
593
604
594
@inlinable
605
- func _vjpMean( squeezingAxes axes: [ Int32 ] ) -> ( Tensor , ( Tensor ) -> Tensor ) {
595
+ func _vjpMean(
596
+ squeezingAxes axes: Tensor < Int32 >
597
+ ) -> ( Tensor , ( Tensor ) -> Tensor ) {
606
598
let value = mean ( squeezingAxes: axes)
607
- return ( value, { [ shape = shapeTensor,
608
- count = axes. map { shape [ $0] } . reduce ( 1 , * ) ] in
609
- $0. broadcast ( toShape: shape) / Tensor( Scalar ( count) )
610
- } )
611
- }
612
-
613
- @inlinable
614
- func _vjpMean( alongAxes axes: [ Int32 ] ) -> ( Tensor , ( Tensor ) -> Tensor ) {
615
- let value = mean ( alongAxes: axes)
616
- return ( value, { [ shape = shapeTensor,
617
- count = axes. map { shape [ $0] } . reduce ( 1 , * ) ] in
618
- $0. broadcast ( toShape: shape) / Tensor( Scalar ( count) )
599
+ let count = Raw . gather ( params: shapeTensor, indices: axes) . product ( )
600
+ return ( value, { [ shape = shapeTensor] in
601
+ $0. broadcast ( toShape: shape) / Tensor( count)
619
602
} )
620
603
}
621
604
}
0 commit comments