Improved derivative performance for broadcasted operations. #142
Conversation
We have been assuming that seeds are always shape-compatible with (if not the same shape as) the original result, and existing implementations have not been broadcasting to the shape of the original result. I think there are other reasons for it, which may lead to a different solution.
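To make the shape question concrete, here is a minimal sketch (not this PR's implementation) of a broadcast-aware pullback for addition. It assumes a reduction helper named unbroadcasted(to:) is available on Tensor; the pullback sums the incoming seed back down to each operand's shape, so it stays valid even when the forward op broadcast its inputs.

```swift
import TensorFlow

// Sketch only: a pullback for `lhs + rhs` that reduces the seed back to each
// operand's shape. `unbroadcasted(to:)` is assumed to sum away broadcast
// dimensions; substitute whatever reduction helper the library actually provides.
func addWithPullback<Scalar: TensorFlowFloatingPoint>(
  _ lhs: Tensor<Scalar>, _ rhs: Tensor<Scalar>
) -> (value: Tensor<Scalar>,
      pullback: (Tensor<Scalar>) -> (Tensor<Scalar>, Tensor<Scalar>)) {
  let value = lhs + rhs
  return (value, { seed in
    (seed.unbroadcasted(to: lhs.shape), seed.unbroadcasted(to: rhs.shape))
  })
}
```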
That's interesting! I actually traced the test failure and saw that the seed shape did not match the result shape when the add VJP was being called. Given that this was being directly invoked by the …
}
let x = Tensor<Float>(ones: [1, 2, 1, 4])
let y = Tensor<Float>(ones: [4, 1, 3, 1])
let (dx, dy) = gradient(at: x, y, in: foo)
In this test case, the use of gradient(at:in:) is not valid, because gradient is only mathematically defined for functions that return a scalar. We can either make foo(_:_:) end with a sum() or use pullback(at:in:).
That's what I thought. This was actually the failing case in the other PR, and that's why I copied it here, but I wasn't sure of the semantics of gradient. I'll add a call to sum() and remove the seed broadcasts then. :)
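A sketch of what that change could look like (the body of foo(_:_:) below is hypothetical, since the excerpt above only shows its closing brace): reducing with sum() makes foo scalar-valued, so gradient(at:_:in:) is well defined without any seed broadcasting.

```swift
import TensorFlow

// Hypothetical body for `foo`: the broadcasted result of `x + y` has shape
// [4, 2, 3, 4], and `sum()` reduces it to a scalar so the gradient is defined.
func foo(_ x: Tensor<Float>, _ y: Tensor<Float>) -> Tensor<Float> {
  return (x + y).sum()
}

let x = Tensor<Float>(ones: [1, 2, 1, 4])
let y = Tensor<Float>(ones: [4, 1, 3, 1])
let (dx, dy) = gradient(at: x, y, in: foo)
```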
AD is fully shape-agnostic, and there's some mathematical consistency to differential operators like valueWithGradient(at:in:):

@inlinable
public func valueWithGradient<T, R>(
  at x: T,
  in f: @differentiable (T) -> Tensor<R>
) -> (value: Tensor<R>, gradient: T.TangentVector)
  where T: Differentiable, R: TensorFlowFloatingPoint {
  let (y, pullback) = valueWithPullback(at: x, in: f)
  precondition(y.rank == 0)
  return (y, pullback(Tensor<R>(1)))
}
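As a usage sketch of the operator above: a rank-0 (scalar) result satisfies the precondition, so the Tensor<R>(1) seed is valid, while a tensor-valued closure would now trap instead of silently being seeded with a mismatched shape.

```swift
import TensorFlow

let x = Tensor<Float>(ones: [2, 3])

// Scalar-valued closure: the result has rank 0, so the precondition passes.
let (value, grad) = valueWithGradient(at: x) { $0.sum() }

// Tensor-valued closure: the result has rank 2, so `precondition(y.rank == 0)`
// would trap at runtime.
// let (badValue, badGrad) = valueWithGradient(at: x) { $0 * 2 }
```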
I also prefer that. What's the cost of preconditions? Are they removed when compiling with optimizations enabled?
From this blog post.
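For reference, and presumably what the linked post explains: assert is only evaluated in -Onone (debug) builds, precondition is still evaluated under -O, and only -Ounchecked removes precondition checks, so the rank check above survives (cheaply) in optimized builds.

```swift
func reciprocal(_ x: Float) -> Float {
  assert(x.isFinite)    // evaluated only in -Onone (debug) builds
  precondition(x != 0)  // evaluated in -Onone and -O; removed only under -Ounchecked
  return 1 / x
}
```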
Awesome, thanks! :) I'll be away for ~30 minutes but will make those changes once I get back.
I'm looking into this now. I can add the precondition for …
I made all necessary changes and all tests pass locally.
It actually does because it calls …
Sounds good. Done in the latest commit. :)
All tests pass. Thank you!
Re-implementation of swiftlang/swift#24408.
@rxwei The reason you were getting errors is that you were differentiating a tensor-valued function (as opposed to a scalar-valued one) and assuming that the gradient seed has the same shape as the operation result. I fixed that by broadcasting the gradient seeds before computing the gradients.
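Illustratively (not the PR's exact code), "broadcasting the gradient seeds" means expanding a seed up to the primal result's shape before handing it to the pullback. The sketch below assumes valueWithPullback(at:_:in:) and a broadcasted(to:) method are available, and foo(_:_:)'s body is hypothetical.

```swift
import TensorFlow

// Hypothetical tensor-valued function whose inputs get broadcast.
func foo(_ x: Tensor<Float>, _ y: Tensor<Float>) -> Tensor<Float> {
  return x + y  // result shape: [4, 2, 3, 4]
}

let x = Tensor<Float>(ones: [1, 2, 1, 4])
let y = Tensor<Float>(ones: [4, 1, 3, 1])
let (value, pb) = valueWithPullback(at: x, y, in: foo)

// Expand a unit seed to the result's shape before calling the pullback, so the
// seed and the primal result are shape-compatible.
// `broadcasted(to:)` is assumed to exist for this expansion.
let seed = Tensor<Float>(1).broadcasted(to: value.shape)
let (dx, dy) = pb(seed)
```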