Skip to content

[AutoDiff] automatically handle fieldwise product spaces #21575

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 31, 2018

Conversation

marcrasi
Copy link

This PR adds a @_fieldwiseProductSpace attribute on the (Co)TangentVector type aliases. This attribute declares that the (co)tangent space is the product space of the (co)tangent spaces of the fields of the struct. Knowing this, the AD pass can correctly differentiate struct_extract instructions.

Hooking this up with derived conformances for Differentiable should be as simple as making the derived conformance put @_fieldwiseProductSpace on its type aliases. (As long as the derived conformance makes fields with matching names. If the derived conformance modifies the field names, then we'll also have to teach the AD pass about that.)

With this approach, we should never need to write special code to synthesize VJPs of getters. Even in the case of cross-module public structs, the AD pass should be able to synthesize VJPs of getters on its own, because the getters are implemented using struct_extract, which the AD pass now knows how to differentiate.

@marcrasi marcrasi requested review from rxwei and dan-zheng December 31, 2018 01:05
@rxwei
Copy link
Contributor

rxwei commented Dec 31, 2018

Even in the case of cross-module public structs, the AD pass should be able to synthesize VJPs of getters on its own, because the getters are implemented using struct_extract, which the AD pass now knows how to differentiate.

It's not the case that getters are always implemented using struct_extract, because resilience allows you to not recompile a piece of client code when a library changes a property from being stored to being computed. Therefore, cross-module struct members need their property VJPs be generated unless the struct is ABI-public.

@marcrasi
Copy link
Author

marcrasi commented Dec 31, 2018

Therefore, cross-module struct members need their property VJPs be generated unless the struct is ABI-public.

Yes, and what I mean is that the AD pass can now synthesize these property VJPs while compiling the library, without additional special-case code in the AD pass. Because when it's compiling the library it sees the implementation of the getter in terms of struct_extract and knows how to differentiate that.

@rxwei
Copy link
Contributor

rxwei commented Dec 31, 2018

That makes sense. Would that require the AD pass to scan through every struct in the current module, find the ones that are Differentiable where the CotangentVector is marked @_fieldwiseProductSpace, and create differentiation tasks for every stored property getter in those structs?

@rxwei
Copy link
Contributor

rxwei commented Dec 31, 2018

Does that also mean that we should handle a special case when calculating symbols for property getters when we encounter a Differentiable struct whose TangentVector/CotangentVector is marked @_fieldwiseProductSpace?

@rxwei
Copy link
Contributor

rxwei commented Dec 31, 2018

I think the solution is excellent for non-public and ABI public structs, BTW. We just need to figure out some details of supporting public structs. Right now, it's unclear whether VJP synthesis at the AST level is easier and would result in fewer special cases here and there for public structs.

@marcrasi
Copy link
Author

Would that require the AD pass to scan through every struct in the current module, find the ones that are Differentiable where the CotangentVector is marked @_fieldwiseProductSpace, and create differentiation tasks for every stored property getter in those structs?

Yeah, we would need something to trigger differentiation. This situation seems the same as the current situation with differentiating unserialized cross-module functions, and we could handle them in similar ways.

For example, right now you need to put @differentiable() on an unserialized function if you want to differentiate it from a different module. Doing the same thing on the stored property should work for differentiating the property from a different module, though I haven't tested it.

Some kind of scan like you describe could also apply to both cases. The scan could go through all the functions (including getters) and differentiate them when they are differentiable. Getters would only be differentiable when the appropriate vector spaces are marked @_fieldwiseProductSpace.

Does that also mean that we should handle a special case when calculating symbols for property getters when we encounter a Differentiable struct whose TangentVector/CotangentVector is marked @_fieldwiseProductSpace?

If we do a scanning solution, yes. But this also applies to unserialized functions, so it's not a unique problem for getters.

With the "you must mark everything @differentiable()" solution, the existing symbol generation code should handle it because it'll see that the getter has @differentiable() and generate symbols accordingly. (Though I have also not tested this).

@rxwei
Copy link
Contributor

rxwei commented Dec 31, 2018

Cool. I think we can make Differentiable synthesis put a @differentiable attribute on each stored property when Cotan/Tan gets @_fieldwiseProductSpace.

@marcrasi
Copy link
Author

@swift-ci please test tensorflow

@marcrasi marcrasi merged commit 02cd429 into swiftlang:tensorflow Dec 31, 2018
@marcrasi marcrasi deleted the automatic-struct-extract-diff branch December 31, 2018 04:58
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants