[TF] Fix gradients in sum(squeezingAxes:) and mean(squeezingAxes:) (#24164)
* [TF] Fix gradients in sum(squeezingAxes:) and mean(squeezingAxes:)
sum(squeezingAxes:) and mean(squeezingAxes:) were throwing an error during the backward pass because the gradients weren't unsqueezed before being broadcast.
Note that this could be refactored nicely if we had a function that took a list of ints for `expandingShape`.
Second note: I may be wrong, but it seems like `_vjpMean(squeezingAxes axes: [Int])` is never used and only the Tensor<Int32> version is.
* Remove unused `_vjpMean` function.
* Update Gradients.swift
* Add test
* Minor edit for consistency.
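
The failure described in the first commit message can be illustrated at the shape level. The following is a minimal sketch in plain Swift with no TensorFlow dependency. The helpers `squeezedShape`, `unsqueezedShape`, and `isBroadcastable` are hypothetical names invented for this illustration and are not part of the Swift for TensorFlow API. It assumes non-negative axes and NumPy-style broadcasting, where shapes are aligned from the trailing dimension.

```swift
// Hypothetical helpers for illustration only; not the Swift for TensorFlow API.
// Assumes all axes are non-negative and valid for the given shape.

/// Shape left after reducing over `axes` and squeezing those dimensions away.
func squeezedShape(_ shape: [Int], reducing axes: [Int]) -> [Int] {
    return shape.enumerated()
        .filter { !axes.contains($0.offset) }
        .map { $0.element }
}

/// Shape obtained by re-inserting size-1 dimensions at the reduced axes.
func unsqueezedShape(_ shape: [Int], at axes: [Int]) -> [Int] {
    var result = shape
    // Insert in ascending order so earlier insertions shift later positions correctly.
    for axis in axes.sorted() {
        result.insert(1, at: axis)
    }
    return result
}

/// Trailing-aligned broadcast check: every dimension must match or be 1.
func isBroadcastable(_ shape: [Int], to target: [Int]) -> Bool {
    guard shape.count <= target.count else { return false }
    for (s, t) in zip(shape.reversed(), target.reversed()) {
        if s != 1 && s != t { return false }
    }
    return true
}

let inputShape = [2, 3, 4]
let axes = [1]

// The incoming gradient has the shape of the squeezed reduction result.
let gradShape = squeezedShape(inputShape, reducing: axes)

// Broadcasting the squeezed gradient directly fails the shape check.
let directOK = isBroadcastable(gradShape, to: inputShape)

// Unsqueezing first restores the reduced axis as size 1, so the broadcast is valid.
let fixedShape = unsqueezedShape(gradShape, at: axes)
let fixedOK = isBroadcastable(fixedShape, to: inputShape)

print(gradShape, directOK, fixedShape, fixedOK)
```

In this sketch `unsqueezedShape` takes a list of ints, which is the kind of helper the refactoring note in the commit message has in mind. Under the trailing-alignment assumption, reducing only over leading axes still passes the direct check, so the mismatch shows up when a non-leading axis is squeezed.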