Skip to content

[AutoDiff] Add SIL differentiability witnesses. #27487

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

dan-zheng
Copy link
Contributor

@dan-zheng dan-zheng commented Oct 2, 2019

SIL differentiability witnesses are a new top-level SIL construct mapping "original" SIL functions to derivatives. They will replace SIL function [differentiable] attributes, additionally enabling cross-module retroactive derivative registration.

SIL differentiability witnesses have the following components:

  • Original SILFunction.
  • Linkage.
  • Parameter indices (IndexSubset).
  • Result indices (IndexSubset).
  • Derivative generic signature (optional).
  • JVP function (optional).
  • VJP function (optional).

Example syntax:

sil_differentiability_witness hidden [parameters 0 1] [results 0] [where τ_0_0 : _Differentiable] @foo : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0, Float) -> @out τ_0_0 {
  jvp: @AD__foo__jvp_src_0_wrt_0 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector)
  vjp: @AD__foo__vjp_src_0_wrt_0 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, Float))
}

This patch adds the SILDifferentiabilityWitness data structure, along with parsing, printing, verification, and serialization (including lookup by key).

The master issue TF-866 tracks follow-up, including SILGen/IRGen/differentiation transform changes.


Early forum discussion.

@dan-zheng dan-zheng added the tensorflow This is for "tensorflow" branch PRs. label Oct 2, 2019
@dan-zheng dan-zheng requested a review from rxwei October 2, 2019 07:47
rxwei and others added 6 commits October 3, 2019 14:40
- Printing compiles but may require changes to be parseable. Namely,
  if there is no suitable utility for parsing standalone generic
  signatures, changes are needed.
- Parsing is a stub.
- Note: it is difficult to test parsing/printing without generating
  SILDifferentiabilityWitness instances.
`parameters (0, 1, ...) results (0, 1, ...) where <...>`
@dan-zheng dan-zheng force-pushed the sil-differentiability-witness branch from 57e9634 to bac0181 Compare October 3, 2019 21:42
@dan-zheng dan-zheng force-pushed the sil-differentiability-witness branch from bac0181 to d2c6ab5 Compare October 11, 2019 05:07
@dan-zheng dan-zheng force-pushed the sil-differentiability-witness branch from d2c6ab5 to 843b631 Compare October 11, 2019 05:11
@dan-zheng dan-zheng changed the title [WIP] [AutoDiff] Add SIL differentiability witnesses. [AutoDiff] Add SIL differentiability witnesses. Oct 11, 2019
@dan-zheng dan-zheng marked this pull request as ready for review October 11, 2019 05:16
Copy link
Contributor

@rxwei rxwei left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. Because type signatures are long and sparse, I'd suggest bringing the configuration to the front of the declaration and using square brackets for them, for example:

    sil_differentiability_witness hidden [parameters 0 1] [results 0] [where τ_0_0 : _Differentiable] @foo : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0, Float) -> @out τ_0_0 {
      jvp: @AD__foo__jvp_src_0_wrt_0 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector)
      vjp: @AD__foo__vjp_src_0_wrt_0 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, Float))
    }
    
  2. Parts of what SILDifferentiabilityWitnessKey stores could have been made a more general concept: AutoDiffConfig. How about defining a AutoDiffConfig in AutoDiff.h to store parameter indices, result indices, and the derivative generic signature? Then SILDifferentiabilityWitnessKey can just be defined as std::pair<StringRef, AutoDiffConfig>.

  3. When the function name of a sil_differentiability_witness is demangleable, consider printing a descriptive comment above the witness declaration, for example:

    // differentiability witness for closure #1 in foo<A>(_:_:) 
    sil_differentiability_witness hidden [parameters 0 1] [results 0] [where τ_0_0 : _Differentiable] @someDemangleableName : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0, Float) -> @out τ_0_0 {

@rxwei rxwei requested a review from jckarter October 11, 2019 20:44
Print original function name in comment.

```
// differentiability witness for foo
sil_differentiability_witness hidden [parameters 0 1] [results 0] [where τ_0_0 : _Differentiable] @foo
```
The contents of `AutoDiffConfig` are all uniqued, so uniquing the product
does not make sense.
Unserialized, to be used for diagnostics.
Will revisit later when revamping the differentiation transform.
@dan-zheng dan-zheng force-pushed the sil-differentiability-witness branch from 4ddcd3f to ad6b7aa Compare October 13, 2019 01:16
@dan-zheng dan-zheng requested a review from rxwei October 13, 2019 01:19
@dan-zheng dan-zheng force-pushed the sil-differentiability-witness branch from 771143a to c3959ad Compare October 13, 2019 05:12
Manually verified parsing/printing.
Chose not to add additional test for now to keep the test small.
Copy link
Contributor

@rxwei rxwei left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are there serialization tests?

Copy link
Contributor

@rxwei rxwei left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The implementation LGTM, but I just wanted to make sure serialization is tested.

@dan-zheng dan-zheng force-pushed the sil-differentiability-witness branch from ef5f53c to 7a070f0 Compare October 13, 2019 07:35
@dan-zheng dan-zheng force-pushed the sil-differentiability-witness branch from 7a070f0 to 69209be Compare October 13, 2019 07:42
Note: deserialization does not work when SIL differentiability witness
references bodyless function declarations.
@dan-zheng
Copy link
Contributor Author

@swift-ci Please test tensorflow

Verification assertion messages should appear on the differentiability witness.
@dan-zheng
Copy link
Contributor Author

@swift-ci Please test tensorflow

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
tensorflow This is for "tensorflow" branch PRs.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants