[AutoDiff] Add Builtin.autodiffGet(JVP|VJP) for extracting AD associated functions #21144

Merged: 1 commit, Dec 10, 2018

@rxwei rxwei commented Dec 8, 2018

Add Builtin.autodiffGetJVP and Builtin.autodiffGetVJP builtins, which correspond to autodiff_function_extract [jvp] and autodiff_function_extract [vjp], respectively.

  • Extend BuiltinGenericSignatureBuilder so that clients are able to add conformance constraints via BuiltinGenericSignatureBuilder::addConformanceRequirement.

  • Add getAutoDiffGetAssociatedFunction in Builtins.cpp, which builds an AD builtin decl for a given arity, a differentiation order, and a throwing flag. Today, Builtin.autodiffGet(JVP|VJP) is only for extracting first-order derivatives of unary @autodiff functions. In the future, we could determine the arity/order/throwing-ness from suffixes to Builtin.autodiffGet(JVP|VJP), e.g. Builtin.autodiffGetJVP_Arity2_Order2_Throwing.

    func foo<T: Differentiable, U: Differentiable>(_ f: @autodiff (T) -> U) {
      let jvp: (T) -> (U, (T.TangentVector) -> U.TangentVector) = Builtin.autodiffGetJVP(f)
      let vjp: (T) -> (U, (U.CotangentVector) -> T.CotangentVector) = Builtin.autodiffGetVJP(f)
    }

  • Remove makeBoundGeneric in Builtins.cpp, which I wrote earlier, because makeBoundGenericType has since been merged from upstream.

  • Improve AutoDiffParameterIndices to include a setAllParams argument, which makes it easy to create an AutoDiffParameterIndices with all indices set.

  • In TypeResolver::resolveASTFunctionTypeParams and TypeResolver::resolveASTFunctionType, the checks for Differentiable conformances of T and U in @autodiff (T) -> U are not correct when T and U are not backed by type decls. They are removed, and SR-9448 tracks a proper fix.

  • Remove ASTContext::isDifferentiable because its implementation is incorrect: It does not handle the case where the input type is a generic type parameter.
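
For readers less familiar with the terminology, the JVP and VJP function shapes used in the foo example of this description can be illustrated outside the compiler. The sketch below is plain C++, not Swift and not code from this PR. It assumes scalar double stand-ins for T, U, and their tangent and cotangent vectors, and squareJVP and squareVJP are hypothetical hand-written associated functions for f(x) = x * x.

```cpp
#include <cassert>
#include <functional>
#include <utility>

// Assumption: scalars stand in for the generic types in the PR description.
using Value = double;      // stands in for T and U
using Tangent = double;    // stands in for T.TangentVector / U.TangentVector
using Cotangent = double;  // stands in for T.CotangentVector / U.CotangentVector

// JVP shape: (T) -> (U, (T.TangentVector) -> U.TangentVector)
// Returns f(x) and a differential that pushes an input perturbation forward.
std::pair<Value, std::function<Tangent(Tangent)>> squareJVP(Value x) {
  std::function<Tangent(Tangent)> differential = [x](Tangent dx) {
    return 2.0 * x * dx;  // d(x * x) = 2x * dx
  };
  return std::make_pair(x * x, differential);
}

// VJP shape: (T) -> (U, (U.CotangentVector) -> T.CotangentVector)
// Returns f(x) and a pullback that carries an output sensitivity backward.
std::pair<Value, std::function<Cotangent(Cotangent)>> squareVJP(Value x) {
  std::function<Cotangent(Cotangent)> pullback = [x](Cotangent dy) {
    return dy * 2.0 * x;  // dy scaled by the derivative 2x
  };
  return std::make_pair(x * x, pullback);
}
```

In the scalar case the differential and the pullback compute the same product. They differ for vector-valued types, where the differential maps input-space tangents to output-space tangents and the pullback maps output-space cotangents back to input-space cotangents.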

@rxwei rxwei added the tensorflow label (for "tensorflow" branch PRs) Dec 8, 2018
@rxwei rxwei requested a review from marcrasi December 8, 2018 19:04
rxwei commented Dec 8, 2018

@swift-ci please test tensorflow

    @@ -489,7 +486,20 @@ namespace {
          InterfaceResult = generator.build(*this);
        }

        // SWIFT_ENABLE_TENSORFLOW
        template <class G>
        void addConformanceRequirement(const G &generator, ProtocolDecl *proto) {

rxwei (Contributor, Author) commented:

@DougGregor, is this the right thing to do? If so, shall I upstream this?

@rxwei rxwei requested a review from dan-zheng December 8, 2018 20:10
@dan-zheng dan-zheng left a comment

Exciting progress, getAutoDiffGetAssociatedFunction LGTM!
I wonder what work is left for generalized differentiability?

    case BuiltinValueKind::AutoDiffGetVJP:
      return getAutoDiffGetAssociatedFunction(Context, Id,
                                              AutoDiffAssociatedFunctionKind::VJP,
                                              /*order*/ 1, /*arity*/ 1);

As you said, there's a mismatch between getAutoDiffGetAssociatedFunction (which supports arbitrary order and arity) and Builtin.autodiffGet(JVP|VJP) (which supports order 1 and arity 1).

Could you please expand on how you plan to handle this mismatch? Is the only solution to create many versions of the builtins like Builtin.autodiffGetJVP_Arity2_Order2_Throwing, as you noted in the PR description?

rxwei (Contributor, Author) replied:

getAutoDiffGetAssociatedFunction is a builder, and Builtin.autodiffGet(JVP|VJP) is a concrete pair of functions created by that builder. When we overload, we will simply declare more builtins, which will then call the builder with a different arity/order/throwing-ness. I think this can also be simplified by extending the parser to recognize the suffixes.
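
The suffix idea mentioned in this reply could look roughly like the following. This is a hypothetical standalone sketch, not the Swift compiler's parser or the code in Builtins.cpp: the ADBuiltinConfig struct, the helper names, and the fixed suffix order (_Arity, then _Order, then _Throwing) are assumptions extrapolated from the single example name Builtin.autodiffGetJVP_Arity2_Order2_Throwing. Omitted suffixes fall back to today's defaults of arity 1, order 1, non-throwing.

```cpp
#include <cassert>
#include <cctype>
#include <cstddef>
#include <string>

enum class AssociatedFunctionKind { JVP, VJP };

// Hypothetical description of one concrete builtin the builder could produce.
struct ADBuiltinConfig {
  AssociatedFunctionKind kind = AssociatedFunctionKind::JVP;
  unsigned arity = 1;
  unsigned order = 1;
  bool throwing = false;
};

// Removes `prefix` from the front of `s` if present.
static bool consumePrefix(std::string &s, const std::string &prefix) {
  if (s.compare(0, prefix.size(), prefix) != 0) return false;
  s.erase(0, prefix.size());
  return true;
}

// Reads a non-empty run of decimal digits from the front of `s`.
static bool consumeNumber(std::string &s, unsigned &out) {
  std::size_t i = 0;
  unsigned value = 0;
  while (i < s.size() && std::isdigit(static_cast<unsigned char>(s[i]))) {
    value = value * 10 + static_cast<unsigned>(s[i] - '0');
    ++i;
  }
  if (i == 0) return false;
  out = value;
  s.erase(0, i);
  return true;
}

// Parses names such as "autodiffGetJVP_Arity2_Order2_Throwing".
// Returns false for unknown base names, malformed or trailing suffixes.
bool parseADBuiltinName(std::string name, ADBuiltinConfig &config) {
  config = ADBuiltinConfig();
  if (consumePrefix(name, "autodiffGetJVP"))
    config.kind = AssociatedFunctionKind::JVP;
  else if (consumePrefix(name, "autodiffGetVJP"))
    config.kind = AssociatedFunctionKind::VJP;
  else
    return false;
  if (consumePrefix(name, "_Arity") && !consumeNumber(name, config.arity))
    return false;
  if (consumePrefix(name, "_Order") && !consumeNumber(name, config.order))
    return false;
  if (consumePrefix(name, "_Throwing")) config.throwing = true;
  return name.empty();  // anything left over is an unrecognized suffix
}
```

With a scheme like this, the resulting configuration would be handed to the builder in place of the hard-coded /*order*/ 1, /*arity*/ 1 arguments, so no additional builtin declarations would be needed per combination.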

@rxwei rxwei merged commit f47da5a into swiftlang:tensorflow Dec 10, 2018
@rxwei rxwei deleted the fill-autodiff-function-with-vjp branch December 10, 2018 21:02