@@ -43,20 +43,12 @@ extension Tensor: VectorNumeric where Scalar: Numeric {
 internal extension Tensor where Scalar: TensorFlowFloatingPoint {
   @inlinable
   static func _vjpMultiply(lhs: Tensor, rhs: Tensor) -> (Tensor, (Tensor) -> (Tensor, Tensor)) {
-    return (lhs * rhs, { [
-      lhsShape = lhs.shape,
-      rhsShape = rhs.shape,
-      lhsShapeTensor = lhs.shapeTensor,
-      rhsShapeTensor = rhs.shapeTensor] v in
-      var lhsGrad = rhs * v
-      var rhsGrad = lhs * v
-      if lhsGrad.shape != lhsShape {
-        lhsGrad = lhsGrad.unbroadcasted(toShape: lhsShapeTensor)
-      }
-      if rhsGrad.shape != rhsShape {
-        rhsGrad = rhsGrad.unbroadcasted(toShape: rhsShapeTensor)
-      }
-      return (lhsGrad, rhsGrad)
+    return (lhs * rhs, { [lhsShape = lhs.shapeTensor, rhsShape = rhs.shapeTensor] v in
+      let lhsGrad = rhs * v
+      let rhsGrad = lhs * v
+      let (lhsAxes, rhsAxes) = Raw.broadcastGradientArgs(s0: lhsShape, s1: rhsShape)
+      return (lhsGrad.sum(squeezingAxes: lhsAxes).reshaped(toShape: lhsShape),
+              rhsGrad.sum(squeezingAxes: rhsAxes).reshaped(toShape: rhsShape))
     })
   }
 }
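
Note (reviewer sketch, not part of the diff): Raw.broadcastGradientArgs wraps
TensorFlow's BroadcastGradientArgs op, which takes the two operand shapes and
returns the axes each gradient must be summed over to undo broadcasting. A
minimal illustration with made-up shapes:

    import TensorFlow

    let lhs = Tensor<Float>(ones: [2, 3])   // shape [2, 3]
    let rhs = Tensor<Float>(ones: [3])      // shape [3], broadcasts to [2, 3]
    let (r0, r1) = Raw.broadcastGradientArgs(s0: lhs.shapeTensor,
                                             s1: rhs.shapeTensor)
    // r0 == []  : lhs was not broadcast; its gradient keeps shape [2, 3].
    // r1 == [0] : rhs was broadcast along axis 0, so its gradient is summed
    //             over that axis and reshaped back to [3].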
@@ -236,12 +228,12 @@ internal extension Tensor where Scalar: TensorFlowFloatingPoint {

   @inlinable
   static func _vjpSubtract(lhs: Tensor, rhs: Scalar) -> (Tensor, (Tensor) -> (Tensor, Scalar)) {
-    return (lhs - rhs, { v in (v, 0 - v.sum().scalarized()) })
+    return (lhs - rhs, { v in (v, -v.sum().scalarized()) })
   }

   @inlinable
   static func _vjpSubtract(lhs: Scalar, rhs: Tensor) -> (Tensor, (Tensor) -> (Scalar, Tensor)) {
-    return (lhs - rhs, { v in (v.sum().scalarized(), 0 - v) })
+    return (lhs - rhs, { v in (v.sum().scalarized(), -v) })
   }

   @inlinable
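
Note (reviewer sketch, not part of the diff): 0 - v and the prefix -v are
equivalent here; the unary operator is simply more idiomatic. For z = lhs - c
with scalar c, dz/dc is -1 at every element, so the pullback sums the incoming
seed and negates it:

    import TensorFlow

    let v = Tensor<Float>([1, 2, 3, 4])      // hypothetical incoming seed
    let dc: Float = -v.sum().scalarized()    // -(1 + 2 + 3 + 4) = -10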
@@ -256,27 +248,19 @@ internal extension Tensor where Scalar: TensorFlowFloatingPoint {

   @inlinable
   static func _vjpDivide(lhs: Tensor, rhs: Tensor) -> (Tensor, (Tensor) -> (Tensor, Tensor)) {
-    return (lhs / rhs, { [
-      lhsShape = lhs.shape,
-      rhsShape = rhs.shape,
-      lhsShapeTensor = lhs.shapeTensor,
-      rhsShapeTensor = rhs.shapeTensor] v in
-      var lhsGrad = v / rhs
-      var rhsGrad = (-lhs) / rhs.squared() * v
-      if lhsGrad.shape != lhsShape {
-        lhsGrad = lhsGrad.unbroadcasted(toShape: lhsShapeTensor)
-      }
-      if rhsGrad.shape != rhsShape {
-        rhsGrad = rhsGrad.unbroadcasted(toShape: rhsShapeTensor)
-      }
-      return (lhsGrad, rhsGrad)
+    return (lhs / rhs, { [lhsShape = lhs.shapeTensor, rhsShape = rhs.shapeTensor] v in
+      let lhsGrad = v / rhs
+      let rhsGrad = -lhs / rhs.squared() * v
+      let (lhsAxes, rhsAxes) = Raw.broadcastGradientArgs(s0: lhsShape, s1: rhsShape)
+      return (lhsGrad.sum(squeezingAxes: lhsAxes).reshaped(toShape: lhsShape),
+              rhsGrad.sum(squeezingAxes: rhsAxes).reshaped(toShape: rhsShape))
     })
   }

   @inlinable
   static func _vjpDivide(lhs: Tensor, rhs: Scalar) -> (Tensor, (Tensor) -> (Tensor, Scalar)) {
     return (lhs / rhs, { v in
-      (v / rhs, (v * (0 - lhs) / Tensor(rhs).squared()).sum().scalarized())
+      (v / rhs, (v * -lhs / Tensor(rhs).squared()).sum().scalarized())
     })
   }

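
Note (reviewer sketch, not part of the diff): the rhs side implements
d(lhs / rhs)/d(rhs) = -lhs / rhs^2, with the same broadcast bookkeeping as
_vjpMultiply above. A quick numeric check with made-up values:

    import TensorFlow

    let lhs = Tensor<Float>([6])
    let rhs = Tensor<Float>([2])
    let rhsGrad = -lhs / rhs.squared()       // -6 / 4 = [-1.5]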
@@ -704,15 +688,12 @@ internal func _vjpPow<T: TensorFlowFloatingPoint>(
   let value = pow(x, y)
   return (value, { v in
     let safeX = x.replacing(with: Tensor<T>(onesLike: x), where: x .<= 0)
-    var gradX = v * y * pow(x, y - 1)
-    var gradY = value * v * log(safeX)
-    if gradX.shape != x.shape {
-      gradX = gradX.unbroadcasted(like: x)
-    }
-    if gradY.shape != y.shape {
-      gradY = gradY.unbroadcasted(like: y)
-    }
-    return (gradX, gradY)
+    let lhsGrad = v * y * pow(x, y - 1)
+    let rhsGrad = value * v * log(safeX)
+    let (lhsShape, rhsShape) = (x.shapeTensor, y.shapeTensor)
+    let (lhsAxes, rhsAxes) = Raw.broadcastGradientArgs(s0: lhsShape, s1: rhsShape)
+    return (lhsGrad.sum(squeezingAxes: lhsAxes).reshaped(toShape: lhsShape),
+            rhsGrad.sum(squeezingAxes: rhsAxes).reshaped(toShape: rhsShape))
   })
 }

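
Note (reviewer sketch, not part of the diff): safeX substitutes ones wherever
x <= 0 so that log(safeX) stays finite; the partials being computed are
d/dx x^y = y * x^(y - 1) and d/dy x^y = x^y * log(x). With made-up values:

    import TensorFlow

    let x = Tensor<Float>([2])
    let y = Tensor<Float>([3])
    let dx = y * pow(x, y - 1)               // 3 * 2^2 = [12.0]
    let dy = pow(x, y) * log(x)              // 8 * ln(2) ≈ [5.545]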
@@ -798,15 +779,12 @@ internal func _vjpMinMaxHelper<T: TensorFlowFloatingPoint>(
   seed: Tensor<T>
 ) -> (Tensor<T>, Tensor<T>) {
   let denominator = 1 + Tensor<T>(x .== y)
-  var gradX = seed * Tensor<T>(x .== originalValue) / denominator
-  var gradY = seed * Tensor<T>(y .== originalValue) / denominator
-  if gradX.shape != x.shape {
-    gradX = gradX.unbroadcasted(like: x)
-  }
-  if gradY.shape != y.shape {
-    gradY = gradY.unbroadcasted(like: y)
-  }
-  return (gradX, gradY)
+  let lhsGrad = seed * Tensor<T>(x .== originalValue) / denominator
+  let rhsGrad = seed * Tensor<T>(y .== originalValue) / denominator
+  let (lhsShape, rhsShape) = (x.shapeTensor, y.shapeTensor)
+  let (lhsAxes, rhsAxes) = Raw.broadcastGradientArgs(s0: lhsShape, s1: rhsShape)
+  return (lhsGrad.sum(squeezingAxes: lhsAxes).reshaped(toShape: lhsShape),
+          rhsGrad.sum(squeezingAxes: rhsAxes).reshaped(toShape: rhsShape))
 }
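
Note (reviewer sketch, not part of the diff): denominator equals 2 wherever
x == y, so a tied min/max splits the seed evenly between the two operands
instead of double-counting it. With made-up values:

    import TensorFlow

    let x = Tensor<Float>([1, 5])
    let y = Tensor<Float>([1, 3])
    let denominator = 1 + Tensor<Float>(x .== y)   // [2, 1]: tie at index 0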

 //===------------------------------------------------------------------------------------------===//