-
Notifications
You must be signed in to change notification settings - Fork 10.5k
[AutoDiff] Add Builtin.autodiffGet(JVP|VJP)
for extracting AD associated functions
#21144
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[AutoDiff] Add Builtin.autodiffGet(JVP|VJP)
for extracting AD associated functions
#21144
Conversation
@swift-ci please test tensorflow |
@@ -489,7 +486,20 @@ namespace { | |||
InterfaceResult = generator.build(*this); | |||
} | |||
|
|||
// SWIFT_ENABLE_TENSORFLOW | |||
template <class G> | |||
void addConformanceRequirement(const G &generator, ProtocolDecl *proto) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@DougGregor, is this the right thing to do? If so, shall I upstream this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Exciting progress, getAutoDiffGetAssociatedFunction
LGTM!
I wonder what work is left for generalized differentiability?
case BuiltinValueKind::AutoDiffGetVJP: | ||
return getAutoDiffGetAssociatedFunction(Context, Id, | ||
AutoDiffAssociatedFunctionKind::VJP, | ||
/*order*/ 1, /*arity*/ 1); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As you said, there's a mismatch between getAutoDiffGetAssociatedFunction
(which supports arbitrary order and arity) and Builtin.autodiffGet(JVP|VJP)
(which supports order 1 and arity 1).
Could you please expand on how you plan to handle this mismatch? Is the only solution to create many versions of the builtins like Builtin.autodiffGetJVP_Arity2_Order2_Throwing
, as you noted in the PR description?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
getAutoDiffGetAssociatedFunction
is a builder, and Builtin.autodiffGet(JVP|VJP)
is a concrete pair of functions created by that builder. When we overload, we will simply declare more builtins, which will then call the builder with a different arity/order/throwing-ness. I think this can also be simplified by extending the parser to recognize the suffixes.
Add
Builtin.autodiffGetJVP
andBuiltin.autodiffGetVJP
builtins, which correspond toautodiff_function_extract [jvp]
andautodiff_function_extract [vjp]
, respectively.Extend
BuiltinGenericSignatureBuilder
so that clients are able to add conformance constraints viaBuiltinGenericSignatureBuilder::addConformanceRequirement
.Add
getAutoDiffGetAssociatedFunction
inBuiltins.cpp
, which builds an AD builtin decl for a given arity, a differentiation order, and a throwing flag. Today,Builtin.autodiffGet(JVP|VJP)
is only for extracting first-order derivatives of unary@autodiff
functions. In the future, we could determine the arity/order/throwing-ness from suffixes toBuiltin.autodiffGet(JVP|VJP)
, e.g.Builtin.autodiffGetJVP_Arity2_Order2_Throwing
.Remove
makeBoundGeneric
inBuiltins.cpp
that I wrote earlier because we've mergedmakeBoundGenericType
from upstream.Improve
AutoDiffParameterIndices
to include asetAllParams
argument which makes it easy to create anAutoDiffParameterIndices
whose all indices are set.In
TypeResolver::resolveASTFunctionTypeParams
andTypeResolver::resolveASTFunctionType
, the checks forDifferentiable
-conformances ofT
andU
in@autodiff (T) -> U
is not correct whenT
andU
are not backed by type decls. They are removed, and SR-9448 tracking a proper fix.Remove
ASTContext::isDifferentiable
because its implementation is incorrect: It does not handle the case where the input type is a generic type parameter.