[TF] Reimplement unbroadcast using on-host axis calculation for performance. #24907
The inefficiency of `unbroadcast(toShape:)`, `unbroadcast(to:)`, and `unbroadcast(like:)` has caused significant performance problems during model training, because these methods dispatch many TensorFlow operations just to compute reduction axes. We were forced to implement them this way in the early GPE era, when neither send/receive nor per-op dispatch was available.

This PR reimplements the unbroadcast operations in terms of host-side logic that computes the axes to reduce along, which significantly reduces TensorFlow operation dispatch overhead. The base implementation changed from `broadcast(toShape:)` to `broadcast(to:)`.

With the new implementation, differentiating broadcasting operators is 37% faster (see simple test script here).
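For readers unfamiliar with the approach, here is a minimal sketch of what host-side axis calculation can look like. This is not the PR's actual code: the method name `unbroadcastSketch(to:)` is illustrative, and the `sum(squeezingAxes:)`/`reshaped(to:)` calls assume Swift for TensorFlow's public reduction and reshape APIs.

```swift
import TensorFlow

extension Tensor where Scalar: Numeric {
  /// A hypothetical host-side unbroadcast: all axis arithmetic happens in
  /// Swift on the host, so only one reduction and one reshape op are
  /// dispatched to TensorFlow.
  func unbroadcastSketch(to shape: TensorShape) -> Tensor {
    // Leading axes that broadcasting prepended (self has higher rank).
    let rankDelta = rank - shape.rank
    var axes = Array(0..<rankDelta)
    // Axes where the target dimension is 1 but this tensor's is not.
    for (axis, dim) in shape.dimensions.enumerated()
      where dim == 1 && self.shape[rankDelta + axis] != 1 {
      axes.append(rankDelta + axis)
    }
    // A single reduction plus a reshape, instead of a chain of
    // shape-manipulation ops.
    return sum(squeezingAxes: axes).reshaped(to: shape)
  }
}
```

For example, unbroadcasting a `[4, 2, 3]` tensor to `[2, 1]` computes `axes == [0, 2]` entirely on the host, so TensorFlow only sees the final reduction and reshape.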
Note: PR #24408, which uses `Raw.broadcastGradientArgs(s0:s1:)` in the `unbroadcast(to:)` method to improve derivative performance, is still necessary for broadcasting binary operations to become faster.
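To sketch why that PR matters: a broadcasting binary operator's pullback needs reduction indices for both operands, and `Raw.broadcastGradientArgs(s0:s1:)` computes both in a single op. The function below is illustrative only, not the actual stdlib derivative; `Raw.sum(_:reductionIndices:)` and `Raw.reshape(_:shape:)` are the raw TensorFlow ops assumed here.

```swift
import TensorFlow

/// Illustrative pullback for a broadcasting binary op such as `+`.
/// `lhsShape`/`rhsShape` are the operands' shapes as Int32 tensors,
/// and `seed` is the incoming gradient of the broadcast result.
func broadcastingAddPullback(
  lhsShape: Tensor<Int32>,
  rhsShape: Tensor<Int32>,
  seed: Tensor<Float>
) -> (Tensor<Float>, Tensor<Float>) {
  // A single op yields the reduction indices for both operands.
  let (r0, r1) = Raw.broadcastGradientArgs(s0: lhsShape, s1: rhsShape)
  let lhsGrad = Raw.reshape(Raw.sum(seed, reductionIndices: r0), shape: lhsShape)
  let rhsGrad = Raw.reshape(Raw.sum(seed, reductionIndices: r1), shape: rhsShape)
  return (lhsGrad, rhsGrad)
}
```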
TODO: Change the `unbroadcast(toShape:)` tests added by #24899 ([AutoDiff] Add more Tensor `broadcast`/`unbroadcast` differentiation tests) to use `unbroadcast(to:)`, since `unbroadcast(to:)` is now the base implementation. A rewritten test might look roughly like the sketch below.
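The test class name, expected values, and exact assertions here are placeholders, not code from the actual test suite:

```swift
import TensorFlow
import XCTest

final class UnbroadcastTests: XCTestCase {
  func testUnbroadcastToDerivative() {
    let x = Tensor<Float>(repeating: 1, shape: [2, 1, 3])
    // Differentiate through the new base implementation directly.
    let grad = gradient(at: x) { (x: Tensor<Float>) -> Tensor<Float> in
      x.unbroadcast(to: [1, 3]).sum()
    }
    // The pullback broadcasts the seed back to the input's shape.
    XCTAssertEqual(grad.shape, x.shape)
  }
}
```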