-
Notifications
You must be signed in to change notification settings - Fork 10.5k
[TF] Fix gradients in sum(squeezinAxes:) and mean(squeezinAxes:) #24164
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
sum(squeezinAxes:) and mean(squeezinAxes:) were throwing an error during the bawckward pass because the gradients weren't unsqueezed before being broadcast. Note that this could be refactored nicely if we had a function that took a list of ints for `expandingShape`. Second note: I may be wrong, but it seems like `_vjpMean(squeezingAxes axes: [Int])` is never used and only the Tensor<Int32> version is.
Reviewing now. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be nice to add tests and an expandingShape(at shapeIndices: [Int])
API, but those needn't block progress.
@swift-ci Please test tensorflow |
I can work on an |
That would be great, if you have time!
Please add tests to |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
That'd be great. You can check it into tensorflow/swift-apis if that's easier! |
@swift-ci Please test tensorflow |
There are actually commented tests for all squeezing things in the file you mentioned Dan. Is there any reason for that? |
Great catch! This PR fixes those concerns. Are you in a position to uncomment the tests and run tests locally? If not, I can do so. Please let me know. |
I can update with the uncommented tests, just not sure how to run them locally. |
Okay. Please update the tests, and I will verify that they pass locally, then trigger CI. Thanks! |
Done, when you have a bit of time, please tell me how you can quickly test one of those files so that I don't bother you next time ;) |
Thanks! Running tests requires a local build of the Swift compiler - if you don't have a local build, quickly testing is difficult. If you do have a local build, you can run tests using To run this particular test, I run the following (you may need to replace
|
Verified that tests pass locally. Triggering CI. |
@swift-ci Please test tensorflow |
@swift-ci Please test tensorflow linux |
Thanks all! |
…ftlang#24164) * [TF] Fix gradients in sum(squeezinAxes:) and mean(squeezinAxes:) sum(squeezinAxes:) and mean(squeezinAxes:) were throwing an error during the bawckward pass because the gradients weren't unsqueezed before being broadcast. Note that this could be refactored nicely if we had a function that took a list of ints for `expandingShape`. Second note: I may be wrong, but it seems like `_vjpMean(squeezingAxes axes: [Int])` is never used and only the Tensor<Int32> version is. * Remove unused `_vjpMean` function. * Update Gradients.swift * Add test * Minor edit for consistency.
sum(squeezinAxes:) and mean(squeezinAxes:) were throwing an error during the backward pass because the gradients weren't unsqueezed before being broadcast, this PR fixes it.
Note that this could be refactored nicely if we had a function that took a list of ints for
expandingShape
.Second note: I may be wrong, but it seems like
_vjpMean(squeezingAxes axes: [Int])
is never used and only the Tensor version is.