[AutoDiff] populate diff witnesses during differentiation #28402
@swift-ci please test tensorflow
Incremental progress sounds good!
// Get lowered argument indices.
auto *paramIndices = A->getParameterIndices();
assert(paramIndices && "Parameter indices should have been resolved");
auto *loweredParamIndices = autodiff::getLoweredParameterIndices(
    paramIndices, decl->getInterfaceType()->castTo<AnyFunctionType>());
// NOTE(TF-893): Extending capacity is necessary when `origSilFnType` has
I wonder if the only case that triggers this "capacity extending due to captured variables" logic is the `@differentiable` attribute on `func original` nested in `func differentiableFunction(from:)`?
Full context: there are known issues regarding differentiation and local variable capture (TF-881). @rxwei mentioned disallowing the `@differentiable` attribute on nested functions for now and creating a builtin to support `func differentiableFunction(from:)`. One known user of `differentiableFunction(from:)` is the custom differentiation tutorial.
Yes, that's the only thing (in the stdlib + swift-apis + tests, at least) that triggers this problem.
It would indeed be nice to remove that and forbid `@differentiable` on nested functions for now.
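To illustrate the "capacity extending" idea discussed in this thread, here is a minimal standalone sketch. It is not the compiler's code: `ToyIndexSubset` and `matchLoweredParameterCount` are hypothetical names made up for this example, and it assumes that captured values show up as extra trailing parameters in the lowered function type, so the parameter index set has to be padded to the larger parameter count.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Toy stand-in for an index subset: bit i is set when parameter i is
// differentiated with respect to. Not the compiler's real index subset type.
struct ToyIndexSubset {
  std::vector<bool> bits;

  std::size_t capacity() const { return bits.size(); }
  bool contains(std::size_t i) const { return i < bits.size() && bits[i]; }

  // Returns a copy with a larger capacity. Existing indices keep their
  // positions; the new trailing slots (assumed here to correspond to
  // captured values) are left unset.
  ToyIndexSubset extendingCapacity(std::size_t newCapacity) const {
    assert(newCapacity >= capacity() && "capacity may only grow");
    ToyIndexSubset result{bits};
    result.bits.resize(newCapacity, false);
    return result;
  }
};

// Hypothetical helper: pads `indices` so that its capacity matches the
// number of parameters in the lowered function type, if that is larger.
ToyIndexSubset matchLoweredParameterCount(const ToyIndexSubset &indices,
                                          std::size_t loweredParamCount) {
  if (loweredParamCount > indices.capacity())
    return indices.extendingCapacity(loweredParamCount);
  return indices;
}
```

In this sketch, a two-parameter function with one captured variable would have its index set padded from capacity 2 to capacity 3, with the captured slot unset. Forbidding `@differentiable` on nested functions, as suggested above, would make this padding unnecessary.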
/// Sets the differentiability witness JVP and VJP to the JVP and VJP in `attr`.
///
/// `attr` must have a JVP and VJP.
static void fillDifferentiabilityWitness(SILModule &module,
Minor: how about settling on the verb "canonicalize" instead of "fill"/"populate"/"process" for `[differentiable]` attributes, differentiability witnesses, and `differentiable_function` instructions?
`processDifferentiableAttribute` and `processDifferentiableFunctionInst` can also be renamed later.
+1 (but only weakly)
Sounds good. Done for this one.
LGTM. Thanks for pushing this forward :) I think we are really really close.
The main point of this PR is to populate differentiability witnesses during the differentiation pass, while still using attributes for everything, so that we can incrementally switch everything over to using witnesses.
A few adjustments in other things were necessary to get this to work:
I'm going to run a toolchain build and swift-apis tests on this before merging, because I'm not super confident that all the assertions I added will always pass.