Skip to content

[AutoDiff] Add more Tensor broadcast/unbroadcast differentiation tests. #24899

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Conversation

bartchr808
Copy link
Contributor

Similar to the already existing set of tests for broadcast(toShape:)/unbroadcast(toShape:) in that this adds the same type of tests, but calling broadcast(to:)/unbroadcast(to:) and broadcast(like:)/unbroadcast(like:) instead.

@bartchr808 bartchr808 added the tensorflow This is for "tensorflow" branch PRs. label May 19, 2019
@bartchr808 bartchr808 requested review from rxwei and dan-zheng May 19, 2019 07:27
@bartchr808
Copy link
Contributor Author

@swift-ci please test tensorflow

1 similar comment
@bartchr808
Copy link
Contributor Author

@swift-ci please test tensorflow

rxwei added a commit to rxwei/swift that referenced this pull request May 20, 2019
…rmance.

The inefficiency of `unbroadcast(toShape:)`, `unbroadcast(to:)`, and `unbroadcast(like:)` has caused significant performance problems during model training because it's performing a lot of TensorFlow operations to achieve axis calculation. We were forced to implement it this way in the early GPE era when neither send/receive nor per-op dispatch was available.

This PR reimplements the unbroadcast operations in terms of host-side logic to compute axes to reduce along. This significantly reduces the TensorFlow opreation dispatch overhead. The base implementation changed from `broadcast(toShape:)` to `broadcast(to:)`.

With the new implementation, differentiating broadcasting operators is 37% faster (see simple test script [here](https://gist.github.com/rxwei/e1488cac5379ba2bc3aff7490e18158f)).

Note:
- Since we now rely on the TensorFlow runtime less, more precondition checks and assertions are added to the newly implemented `unbroadcast(to:)` method.
- The part of swiftlang#24408 that uses `Raw.broadcastGradientArgs(s0:s1:)` is still necessary for broadcasting binary operations to become faster.

TODO:
- Change `unbroadcast(toShape:)` tests added by swiftlang#24899 to use `unbroadcast(to:)`, since `unbroadcast(to:)` is now the base implementation.
rxwei added a commit that referenced this pull request May 20, 2019
…rmance. (#24907)

The inefficiency of `unbroadcast(toShape:)`, `unbroadcast(to:)`, and `unbroadcast(like:)` has caused significant performance problems during model training because it's performing a lot of TensorFlow operations to achieve axis calculation. We were forced to implement it this way in the early GPE era when neither send/receive nor per-op dispatch was available.

This PR reimplements the unbroadcast operations in terms of host-side logic to compute axes to reduce along. This significantly reduces the TensorFlow opreation dispatch overhead. The base implementation changed from `broadcast(toShape:)` to `broadcast(to:)`.

With the new implementation, differentiating broadcasting operators is 37% faster (see simple test script [here](https://gist.github.com/rxwei/e1488cac5379ba2bc3aff7490e18158f)).

Note:
- Since we now rely on the TensorFlow runtime less, more precondition checks and assertions are added to the newly implemented `unbroadcast(to:)` method.
- The part of #24408 that uses `Raw.broadcastGradientArgs(s0:s1:)` is still necessary for broadcasting binary operations to become faster.

TODO:
- Change `unbroadcast(toShape:)` tests added by #24899 to use `unbroadcast(to:)`, since `unbroadcast(to:)` is now the base implementation.
@bartchr808
Copy link
Contributor Author

Closing PR due to refactoring moving Tensor to tensorflow/swift-apis found in this PR.

@bartchr808 bartchr808 closed this Jun 13, 2019
@bartchr808 bartchr808 deleted the TF-509-tensor-broadcast-differentiable branch June 13, 2019 21:42
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
tensorflow This is for "tensorflow" branch PRs.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant