-
Notifications
You must be signed in to change notification settings - Fork 10.5k
[AutoDiff] automatically handle fieldwise product spaces #21575
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[AutoDiff] automatically handle fieldwise product spaces #21575
Conversation
It's not the case that getters are always implemented using |
Yes, and what I mean is that the AD pass can now synthesize these property VJPs while compiling the library, without additional special-case code in the AD pass. Because when it's compiling the library it sees the implementation of the getter in terms of |
That makes sense. Would that require the AD pass to scan through every struct in the current module, find the ones that are |
Does that also mean that we should handle a special case when calculating symbols for property getters when we encounter a |
I think the solution is excellent for non-public and ABI public structs, BTW. We just need to figure out some details of supporting public structs. Right now, it's unclear whether VJP synthesis at the AST level is easier and would result in fewer special cases here and there for public structs. |
Yeah, we would need something to trigger differentiation. This situation seems the same as the current situation with differentiating unserialized cross-module functions, and we could handle them in similar ways. For example, right now you need to put Some kind of scan like you describe could also apply to both cases. The scan could go through all the functions (including getters) and differentiate them when they are differentiable. Getters would only be differentiable when the appropriate vector spaces are marked
If we do a scanning solution, yes. But this also applies to unserialized functions, so it's not a unique problem for getters. With the "you must mark everything |
Cool. I think we can make |
@swift-ci please test tensorflow |
This PR adds a
@_fieldwiseProductSpace
attribute on the (Co)TangentVector type aliases. This attribute declares that the (co)tangent space is the product space of the (co)tangent spaces of the fields of the struct. Knowing this, the AD pass can correctly differentiatestruct_extract
instructions.Hooking this up with derived conformances for Differentiable should be as simple as making the derived conformance put
@_fieldwiseProductSpace
on its type aliases. (As long as the derived conformance makes fields with matching names. If the derived conformance modifies the field names, then we'll also have to teach the AD pass about that.)With this approach, we should never need to write special code to synthesize VJPs of getters. Even in the case of cross-module public structs, the AD pass should be able to synthesize VJPs of getters on its own, because the getters are implemented using
struct_extract
, which the AD pass now knows how to differentiate.