[TF] Remove `unbroadcast(to:)` and improve derivative performance. #24408
Conversation
@pschuh I don't have time to look into it in the next couple of days, but it'd be great if you could take a look!
In the pullback for operators that broadcast, use `Raw.broadcastGradientArgs(s0:s1:)` to compute reduction indices instead of using the inefficient `unbroadcast(to:)`. `unbroadcast(to:)` was introduced only for defining derivatives for broadcasting operators and has no practical use, so now we remove it.

Operators affected:
- `Tensor.+(_:_:)`
- `Tensor.-(_:_:)`
- `Tensor.*(_:_:)`
- `Tensor./(_:_:)`
- `min(_:_:)`
- `max(_:_:)`
- `pow(_:_:)`
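For context, here is a minimal sketch of the pullback pattern this PR adopts for a broadcasting binary operator, modeled on the diff discussed below. The exact declarations may differ from the PR; `shapeTensor` and the `sum(squeezingAxes:)`/`reshaped(toShape:)` overloads taking `Tensor<Int32>` arguments are assumed here.

```swift
import TensorFlow

extension Tensor where Scalar : TensorFlowFloatingPoint {
  // Sketch only: a VJP for broadcasting `+` in the style this PR introduces.
  static func _vjpAdd(
    lhs: Tensor, rhs: Tensor
  ) -> (Tensor, (Tensor) -> (Tensor, Tensor)) {
    return (lhs + rhs, { [lhsShape = lhs.shapeTensor, rhsShape = rhs.shapeTensor] v in
      // Let TensorFlow compute which axes each operand was broadcast along.
      let (lhsAxes, rhsAxes) = Raw.broadcastGradientArgs(s0: lhsShape, s1: rhsShape)
      // Sum the incoming gradient over those axes, then restore each operand's
      // original shape (including any size-1 dimensions) with a reshape.
      return (v.sum(squeezingAxes: lhsAxes).reshaped(toShape: lhsShape),
              v.sum(squeezingAxes: rhsAxes).reshaped(toShape: rhsShape))
    })
  }
}
```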
Force-pushed from 4511398 to 0902283.
```diff
@@ -267,14 +278,14 @@ extension Tensor where Scalar : TensorFlowFloatingPoint {
   static func _vjpSubtract(
     lhs: Tensor, rhs: Scalar
   ) -> (Tensor, (Tensor) -> (Tensor, Scalar)) {
     return (lhs - rhs, { v in (v, 0 - v.sum().scalarized()) })
```
This is some legacy code introduced in the early days, when a `where` clause on `@differentiable` was not supported. Now it is fixed for the better.
```diff
-           v.unbroadcast(toShape: rhsShape))
+    let (lhsAxes, rhsAxes) =
+      Raw.broadcastGradientArgs(s0: lhsShape, s1: rhsShape)
+    return (v.sum(squeezingAxes: lhsAxes).reshaped(toShape: lhsShape),
```
I haven't looked too closely, but I suspect that this extra reshape is not necessary. `lhsAxes` should be sufficient to recover the original shape.
That’s what I tried initially (more specifically, `sum(alongAxes:)`), but it didn’t work.
@pschuh @rxwei The reshape is needed for handling dimensions of size 1. For example, say you do:

```swift
// x has shape [B, 5]
// y has shape [5]
// result has shape [B, 5]
let result = x + y
```

In this case, the broadcast indices for the gradient w.r.t. `y` will be `[0]`, and so we’ll do something like:

```swift
let yGrad = seed.sum(alongAxes: [0]) // No reshape needed.
```

Now, let `y` have shape `[1, 5]`, which still broadcasts correctly for this example. The broadcast indices for the gradient will also be `[0]`. However, we need the reshape to recover the dimensions of size 1, so the gradient needs to be computed as:

```swift
let yGrad = seed.sum(alongAxes: [0]).reshaped(to: y.shape)
```

Having said that, I have a working implementation of these changes that I made as part of a future `swift-apis` PR. I’ll try to open a PR here for this ASAP, but haven’t gotten the chance yet because I’m traveling to ICLR this week.
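To make the shape bookkeeping concrete, here is a small standalone sketch of the two cases described above, with B = 4 and using `sum(squeezingAxes:)` plus a reshape (names and overloads are assumed, not taken from this PR):

```swift
import TensorFlow

// Incoming gradient (seed) with the broadcast result's shape [4, 5], i.e. B = 4.
let seed = Tensor<Float>(ones: [4, 5])

// Case 1: y has shape [5]. The broadcast axis is 0, and summing while
// squeezing that axis already yields the right shape.
let yGrad1 = seed.sum(squeezingAxes: 0)
print(yGrad1.shape) // [5]

// Case 2: y has shape [1, 5]. The broadcast axis is still 0, but a reshape
// is needed afterwards to restore the size-1 leading dimension.
let yGrad2 = seed.sum(squeezingAxes: 0).reshaped(to: [1, 5])
print(yGrad2.shape) // [1, 5]
```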
To make sure we don't regress in the future, could you add a quick test case in your other PR to `swift-apis`? :-)
Yeap, I will go ahead and add that. Given that the merge already happened, is it ok to make this change after we move stdlib to swift-apis? I'll update the two PRs doing the move tonight.
Yes. Thanks!
…rmance. (#24907)

The inefficiency of `unbroadcast(toShape:)`, `unbroadcast(to:)`, and `unbroadcast(like:)` has caused significant performance problems during model training, because they perform many TensorFlow operations to achieve axis calculation. We were forced to implement it this way in the early GPE era, when neither send/receive nor per-op dispatch was available.

This PR reimplements the unbroadcast operations in terms of host-side logic that computes the axes to reduce along. This significantly reduces TensorFlow operation dispatch overhead. The base implementation changed from `broadcast(toShape:)` to `broadcast(to:)`. With the new implementation, differentiating broadcasting operators is 37% faster (see the simple test script [here](https://gist.github.com/rxwei/e1488cac5379ba2bc3aff7490e18158f)).

Note:
- Since we now rely less on the TensorFlow runtime, more precondition checks and assertions are added to the newly implemented `unbroadcast(to:)` method.
- The part of #24408 that uses `Raw.broadcastGradientArgs(s0:s1:)` is still necessary for broadcasting binary operations to become faster.

TODO:
- Change the `unbroadcast(toShape:)` tests added by #24899 to use `unbroadcast(to:)`, since `unbroadcast(to:)` is now the base implementation.
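For illustration, here is a minimal sketch of what host-side axis computation for unbroadcasting can look like. This is not the code from #24907; the method name is hypothetical, and the `sum(squeezingAxes:)` overload taking an axis array is assumed.

```swift
import TensorFlow

extension Tensor where Scalar : TensorFlowFloatingPoint {
  /// Hypothetical sketch: reduce `self` back to `shape` after broadcasting,
  /// computing the reduction axes on the host rather than with TensorFlow ops.
  func unbroadcastSketch(to shape: TensorShape) -> Tensor {
    precondition(rank >= shape.rank, "Cannot unbroadcast to a higher-rank shape.")
    let rankDifference = rank - shape.rank
    // Leading axes that broadcasting added...
    var axes = Array(0..<rankDifference)
    // ...plus axes where the target dimension is 1 and may have been expanded.
    for (i, dimension) in shape.dimensions.enumerated() where dimension == 1 {
      axes.append(rankDifference + i)
    }
    // One reduction and one reshape, instead of many shape-manipulation ops.
    return sum(squeezingAxes: axes).reshaped(to: shape)
  }
}
```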
@rxwei Done in tensorflow/swift-apis#142.
Re-implementation of swiftlang/swift#24408. In the pullback for operators that broadcast, use `Raw.broadcastGradientArgs(s0:s1:)` to compute reduction indices instead of using the inefficient `unbroadcast(to:)`.
In the pullback for operators that broadcast, use `Raw.broadcastGradientArgs(s0:s1:)` to compute reduction indices instead of using the inefficient `unbroadcast(to:)`. `unbroadcast(to:)` was introduced only for defining derivatives for broadcasting operators and has no practical use, so now we remove it.

Operators affected:
- `Tensor.+(_:_:)`
- `Tensor.-(_:_:)`
- `Tensor.*(_:_:)`
- `Tensor./(_:_:)`
- `min(_:_:)`
- `max(_:_:)`
- `pow(_:_:)`

TODO before merging:
- Currently there's a failure on `+` (see the test being commented out). Figure out what's wrong and fix it.
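A sketch of the kind of regression test discussed above for the broadcasting `+` pullback. The test class and method names are hypothetical, and the XCTest setup and `gradient(at:_:in:)` API are assumed to be available in the swift-apis test suite.

```swift
import TensorFlow
import XCTest

final class BroadcastingPullbackTests: XCTestCase {
  func testBroadcastedAddGradientShapes() {
    let x = Tensor<Float>(ones: [4, 5])
    let y = Tensor<Float>(ones: [1, 5])
    // Differentiate a scalar-valued function of the broadcasting `+`.
    let (dx, dy) = gradient(at: x, y) { ($0 + $1).sum() }
    // The gradients must come back with the operands' original shapes,
    // including the size-1 leading dimension of `y`.
    XCTAssertEqual(dx.shape, x.shape)
    XCTAssertEqual(dy.shape, y.shape)
  }
}
```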