@@ -210,7 +210,10 @@ extension Tensor where Scalar : TensorFlowFloatingPoint {
210
210
) -> ( Tensor , ( Tensor ) -> ( Tensor , Tensor ) ) {
211
211
return ( lhs + rhs, {
212
212
[ lhsShape = lhs. shapeTensor, rhsShape = rhs. shapeTensor] v in
213
- return ( v. unbroadcast ( toShape: lhsShape) , v. unbroadcast ( toShape: rhsShape) )
213
+ let ( lhsAxes, rhsAxes) =
214
+ Raw . broadcastGradientArgs ( s0: lhsShape, s1: rhsShape)
215
+ return ( v. sum ( squeezingAxes: lhsAxes) . reshaped ( toShape: lhsShape) ,
216
+ v. sum ( squeezingAxes: rhsAxes) . reshaped ( toShape: rhsShape) )
214
217
} )
215
218
}
216
219
@@ -220,30 +223,38 @@ extension Tensor where Scalar : TensorFlowFloatingPoint {
220
223
) -> ( Tensor , ( Tensor ) -> ( Tensor , Tensor ) ) {
221
224
return ( lhs - rhs, {
222
225
[ lhsShape = lhs. shapeTensor, rhsShape = rhs. shapeTensor] v in
223
- return ( v. unbroadcast ( toShape: lhsShape) ,
224
- - v. unbroadcast ( toShape: rhsShape) )
226
+ let ( lhsAxes, rhsAxes) =
227
+ Raw . broadcastGradientArgs ( s0: lhsShape, s1: rhsShape)
228
+ return ( v. sum ( squeezingAxes: lhsAxes) . reshaped ( toShape: lhsShape) ,
229
+ - v. sum ( squeezingAxes: rhsAxes) . reshaped ( toShape: rhsShape) )
225
230
} )
226
231
}
227
232
228
233
@inlinable
229
234
static func _vjpMultiply(
230
235
lhs: Tensor , rhs: Tensor
231
236
) -> ( Tensor , ( Tensor ) -> ( Tensor , Tensor ) ) {
232
- return ( lhs * rhs, {
233
- [ lhsShape = lhs. shapeTensor, rhsShape = rhs. shapeTensor] v in
234
- ( ( rhs * v) . unbroadcast ( toShape: lhsShape) ,
235
- ( lhs * v) . unbroadcast ( toShape: rhsShape) )
237
+ return ( lhs * rhs, { v in
238
+ let ( lhsShape, rhsShape) = ( lhs. shapeTensor, rhs. shapeTensor)
239
+ let ( lhsAxes, rhsAxes) =
240
+ Raw . broadcastGradientArgs ( s0: lhsShape, s1: rhsShape)
241
+ return ( ( rhs * v) . sum ( squeezingAxes: lhsAxes) . reshaped ( toShape: lhsShape) ,
242
+ ( lhs * v) . sum ( squeezingAxes: rhsAxes) . reshaped ( toShape: rhsShape) )
236
243
} )
237
244
}
238
245
239
246
@inlinable
240
247
static func _vjpDivide(
241
248
lhs: Tensor , rhs: Tensor
242
249
) -> ( Tensor , ( Tensor ) -> ( Tensor , Tensor ) ) {
243
- return ( lhs / rhs, {
244
- [ lhsShape = lhs. shapeTensor, rhsShape = rhs. shapeTensor] v in
245
- ( ( v / rhs) . unbroadcast ( toShape: lhsShape) ,
246
- ( ( - lhs) / rhs. squared ( ) * v) . unbroadcast ( toShape: rhsShape) )
250
+ return ( lhs / rhs, { v in
251
+ let ( lhsShape, rhsShape) = ( lhs. shapeTensor, rhs. shapeTensor)
252
+ let ( lhsAxes, rhsAxes) =
253
+ Raw . broadcastGradientArgs ( s0: lhsShape, s1: rhsShape)
254
+ return ( ( v / rhs) . sum ( squeezingAxes: lhsAxes)
255
+ . reshaped ( toShape: lhsShape) ,
256
+ ( - lhs / rhs. squared ( ) * v) . sum ( squeezingAxes: rhsAxes)
257
+ . reshaped ( toShape: rhsShape) )
247
258
} )
248
259
}
249
260
}
@@ -267,14 +278,14 @@ extension Tensor where Scalar : TensorFlowFloatingPoint {
267
278
static func _vjpSubtract(
268
279
lhs: Tensor , rhs: Scalar
269
280
) -> ( Tensor , ( Tensor ) -> ( Tensor , Scalar ) ) {
270
- return ( lhs - rhs, { v in ( v, 0 - v. sum ( ) . scalarized ( ) ) } )
281
+ return ( lhs - rhs, { v in ( v, - v. sum ( ) . scalarized ( ) ) } )
271
282
}
272
283
273
284
@inlinable
274
285
static func _vjpSubtract(
275
286
lhs: Scalar , rhs: Tensor
276
287
) -> ( Tensor , ( Tensor ) -> ( Scalar , Tensor ) ) {
277
- return ( lhs - rhs, { v in ( v. sum ( ) . scalarized ( ) , 0 - v) } )
288
+ return ( lhs - rhs, { v in ( v. sum ( ) . scalarized ( ) , - v) } )
278
289
}
279
290
280
291
@inlinable
@@ -296,7 +307,7 @@ extension Tensor where Scalar : TensorFlowFloatingPoint {
296
307
lhs: Tensor , rhs: Scalar
297
308
) -> ( Tensor , ( Tensor ) -> ( Tensor , Scalar ) ) {
298
309
return ( lhs / rhs, { v in
299
- ( v / rhs, ( v * ( 0 - lhs) / Tensor( rhs) . squared ( ) ) . sum ( ) . scalarized ( ) )
310
+ ( v / rhs, ( v * - lhs / Tensor( rhs) . squared ( ) ) . sum ( ) . scalarized ( ) )
300
311
} )
301
312
}
302
313
@@ -317,25 +328,30 @@ func _vjpMinMaxHelper<T : TensorFlowFloatingPoint>(
317
328
let denom = 1 + Tensor < T > ( x .== y)
318
329
let dfdx = vector * Tensor < T > ( x .== originalValue) / denom
319
330
let dfdy = vector * Tensor < T > ( y .== originalValue) / denom
320
- return ( dfdx. unbroadcast ( like: x) , dfdy. unbroadcast ( like: y) )
331
+ let ( xShape, yShape) = ( x. shapeTensor, y. shapeTensor)
332
+ let ( xAxes, yAxes) = Raw . broadcastGradientArgs ( s0: xShape, s1: yShape)
333
+ return ( dfdx. sum ( squeezingAxes: xAxes) . reshaped ( toShape: xShape) ,
334
+ dfdy. sum ( squeezingAxes: yAxes) . reshaped ( toShape: yShape) )
321
335
}
322
336
323
337
@inlinable
324
338
func _vjpMax< T : TensorFlowFloatingPoint > (
325
339
_ x: Tensor < T > , _ y: Tensor < T >
326
340
) -> ( Tensor < T > , ( Tensor < T > ) -> ( Tensor < T > , Tensor < T > ) ) {
327
341
let value = max ( x, y)
328
- return ( value,
329
- { v in _vjpMinMaxHelper ( x, y, originalValue: value, vector: v) } )
342
+ return ( value, { v in
343
+ _vjpMinMaxHelper ( x, y, originalValue: value, vector: v)
344
+ } )
330
345
}
331
346
332
347
@inlinable
333
348
func _vjpMin< T : TensorFlowFloatingPoint > (
334
349
_ x: Tensor < T > , _ y: Tensor < T >
335
350
) -> ( Tensor < T > , ( Tensor < T > ) -> ( Tensor < T > , Tensor < T > ) ) {
336
351
let value = min ( x, y)
337
- return ( value,
338
- { v in _vjpMinMaxHelper ( x, y, originalValue: value, vector: v) } )
352
+ return ( value, { v in
353
+ _vjpMinMaxHelper ( x, y, originalValue: value, vector: v)
354
+ } )
339
355
}
340
356
341
357
@inlinable
@@ -344,8 +360,12 @@ func _vjpPow<T : TensorFlowFloatingPoint>(
344
360
) -> ( Tensor < T > , ( Tensor < T > ) -> ( Tensor < T > , Tensor < T > ) ) {
345
361
let value = pow ( x, y)
346
362
return ( value, { v in
347
- ( ( v * y * pow( x, y- 1 ) ) . unbroadcast ( like: x) ,
348
- ( v * log( x) * value) . unbroadcast ( like: y) )
363
+ let ( xShape, yShape) = ( x. shapeTensor, y. shapeTensor)
364
+ let ( xAxes, yAxes) = Raw . broadcastGradientArgs ( s0: xShape, s1: yShape)
365
+ return ( ( v * y * pow( x, y- 1 ) ) . sum ( squeezingAxes: xAxes)
366
+ . reshaped ( toShape: xShape) ,
367
+ ( v * log( x) * value) . sum ( squeezingAxes: yAxes)
368
+ . reshaped ( toShape: yShape) )
349
369
} )
350
370
}
351
371
0 commit comments