Skip to content

Commit b608a86

Browse files
sguggerrxwei
authored andcommitted
[TF] Refactor gradients computation with expandingShape (#24249)
* Refactor with new expandingShape function * Add space Co-Authored-By: sgugger <[email protected]> * Add space Co-Authored-By: sgugger <[email protected]>
1 parent 0ce7594 commit b608a86

File tree

1 file changed

+6
-8
lines changed

1 file changed

+6
-8
lines changed

stdlib/public/TensorFlow/Gradients.swift

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -577,10 +577,9 @@ extension Tensor where Scalar : TensorFlowFloatingPoint {
577577
squeezingAxes axes: Tensor<Int32>
578578
) -> (Tensor, (Tensor) -> Tensor) {
579579
let value = sum(squeezingAxes: axes)
580-
return (value, { [shape = shapeTensor] in
581-
var res = $0
582-
for i in axes.array.scalars { res = res.expandingShape(at: Int(i)) }
583-
return res.broadcast(toShape: shape)
580+
return (value, { [shape = shapeTensor] v in
581+
let unsqueezed = v.expandingShape(at: axes.scalars.map { Int($0) })
582+
return unsqueezed.broadcast(toShape: shape)
584583
})
585584
}
586585

@@ -599,10 +598,9 @@ extension Tensor where Scalar : TensorFlowFloatingPoint {
599598
) -> (Tensor, (Tensor) -> Tensor) {
600599
let value = mean(squeezingAxes: axes)
601600
let count = Raw.gather(params: shapeTensor, indices: axes).product()
602-
return (value, { [shape = shapeTensor] in
603-
var res = $0
604-
for i in axes.array.scalars { res = res.expandingShape(at: Int(i)) }
605-
return res.broadcast(toShape: shape) / Tensor(count)
601+
return (value, { [shape = shapeTensor] v in
602+
let unsqueezed = v.expandingShape(at: axes.scalars.map { Int($0) })
603+
return unsqueezed.broadcast(toShape: shape) / Tensor(count)
606604
})
607605
}
608606
}

0 commit comments

Comments
 (0)