[AutoDiff] Revamp differentiation transform. #24845
Conversation
Force-pushed e9d033a to 415a259
@swift-ci Please test tensorflow
@@ -525,6 +525,15 @@ bool ReabstractionInfo::prepareAndCheck(ApplySite Apply, SILFunction *Callee,
     return false;
   }

   // SWIFT_ENABLE_TENSORFLOW
   // Disable specialization for instructions that are operands of
This change is necessary to prevent post-differentiation SIL passes from generating invalid SIL. In particular, it disables specialization of operands to `autodiff_function` instructions.
It may be possible to teach the specializer how to handle `autodiff_function` instructions, but I don't know how much work that would take; it seems difficult.
@@ -74,7 +74,15 @@ static bool foldInverseReabstractionThunks(PartialApplyInst *PAI,
 SILInstruction *SILCombiner::visitPartialApplyInst(PartialApplyInst *PAI) {
   // partial_apply without any substitutions or arguments is just a
   // thin_to_thick_function.
-  if (!PAI->hasSubstitutions() && (PAI->getNumArguments() == 0)) {
+  if (!PAI->hasSubstitutions() && (PAI->getNumArguments() == 0) &&
This change is necessary to prevent post-differentiation SIL passes from generating invalid SIL.
This change is rather innocuous and can be upstreamed to master. It just adds a previously unexercised check.
  // redesign: currently, each original function + attribute pair is mapped
  // only to one invoker.
  /*
  DifferentiationInvoker indirect(ai, this->original, attr);
Note: indirect differentiation invokers are currently not handled post-refactoring. I've disabled the test for now.
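For context on the terminology, here is a minimal Swift sketch of the indirect case, written for illustration (it uses the `@differentiable` and `gradient(at:in:)` spellings that appear in this PR's tests; `outer` and `inner` are hypothetical names): differentiating a function that calls an unannotated function triggers differentiation of the callee indirectly, from the call site.

```swift
// `inner` has no `@differentiable` attribute of its own.
func inner(_ x: Float) -> Float {
  return x * x
}

// Differentiating `outer` also requires differentiating the `inner(x)` call;
// that nested request, originating from an apply site rather than from an
// attribute or a differential operator, is an indirect differentiation invoker.
@differentiable
func outer(_ x: Float) -> Float {
  return inner(x) + 1
}

let dydx = gradient(at: 3, in: outer)
print(dydx) // 6.0, since d/dx (x * x + 1) = 2x
```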
- Directly generate primal code in VJP functions.
- Rename `PrimalGenCloner` to `VJPEmitter`.
- This unblocks control-flow support. Primal data structures can be determined based on activity analysis.
- Refactor differentiation transform to use an iterative `autodiff_function` instruction worklist instead of a fixed `DifferentiationTask` worklist.
- `VJPEmitter::visitApply` now creates `autodiff_function` and `autodiff_function_extract` instructions instead of directly emitting associated function references.
- An iterative loop canonicalizes `autodiff_function` instructions, filling in missing associated functions.
- Remove unnecessary auxiliary data structures.
- `DifferentiationTask`: replace with map from original function and `[differentiable]` attribute to `DifferentiationInvoker`.
- `PrimalGen`, `AdjointGen`: replace with `autodiff_function` iterative loop.
Force-pushed 415a259 to 1efbac6
Force-pushed 1efbac6 to 9382eaf
@@ -6,14 +6,13 @@
Could you please add a note in this file to say what's unexpected/broken?
I'm actually not sure what the expected behavior is for these tests. I'll debug a bit to see what's going on.
At a glance, the changes seem less than ideal: the old diagnostics were more informative and appeared at the user's `gradient` invocation.
--- a/test/AutoDiff/autodiff_indirect_diagnostics.swift
+++ b/test/AutoDiff/autodiff_indirect_diagnostics.swift
@@ -6,14 +6,13 @@
// Test unmet generic requirements.
-// expected-error @+1 {{function is not differentiable}}
@differentiable
-// expected-note @+1 {{when differentiating this function definition}}
func generic<T: Differentiable & FloatingPoint>(_ x: T) -> T {
+ // expected-error @+2 {{expression is not differentiable}}
// expected-note @+1 {{member is not differentiable because the corresponding protocol requirement is not '@differentiable'}}
return x + 1
}
-_ = gradient(at: 1.0, in: generic) // expected-error {{function is not differentiable}}
+_ = gradient(at: 1.0, in: generic)
I haven't quite made sense of the diagnostic behavior changes yet.
We plan to change/fix diagnostic behavior soon, so it would be good to revisit these changes as part of that follow-up.
Ready for re-review. cc @rxwei
The patch LGTM overall. Huge congrats on getting this done!
There are two improvements that I think should be made before merging:

- The `AdditiveArithmetic.zero` emission logic in a thunk builder should be unified with what's already in `AdjointEmitter`.
- I suspect there's a performance regression caused by the extra `autodiff_function` and `autodiff_function_extract` instructions being generated, which could pose a barrier to later optimizations. Folding trivial `autodiff_function`/`autodiff_function_extract` pairs can be done near the logic that canonicalizes `autodiff_function` instructions in the main loop.
Other than those, we should talk about fixing diagnostics for indirect calls soon.
- Create common function `emitZeroIntoBuffer`.
- Shared by `AdjointEmitter::emitZeroIndirect` and the `buildZeroArgument` lambda.
- Minor naming/style changes.
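For reference, the value this shared helper materializes is `AdditiveArithmetic.zero`. A tiny source-level sketch for illustration (the generic `zero(of:)` function below is hypothetical, not part of the patch):

```swift
// Returns the additive identity of any AdditiveArithmetic-conforming type.
// `emitZeroIntoBuffer` emits the SIL-level equivalent of this `T.zero` access,
// presumably storing the result into a caller-provided buffer.
func zero<T: AdditiveArithmetic>(of type: T.Type) -> T {
  return T.zero
}

print(zero(of: Float.self)) // 0.0
print(zero(of: Int.self))   // 0
```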
Big thanks to @rxwei for the thorough review and advice!
Force-pushed 064e61e to 1433de0
Fold `autodiff_function_extract` users of `autodiff_function` instructions, directly replacing them with operands of the `autodiff_function` instruction. If the `autodiff_function` instruction has no non-`autodiff_function_extract` users, delete the instruction itself after folding. The `differentiation-skip-folding-autodiff-function-extraction` flag disables folding for SIL testing purposes.
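For context, here is a minimal sketch, written for illustration, of the source pattern that gives rise to such instruction pairs (using the `gradient(at:in:)` spelling from this PR's tests; the folding itself happens entirely in SIL):

```swift
// Converting the closure literal to the differential operator's
// `@differentiable` function-typed parameter forms an `autodiff_function`
// instruction in SIL; the operator's implementation then extracts the VJP,
// which is an `autodiff_function_extract`. When the extract's operand is a
// freshly formed `autodiff_function`, the pair can be folded away.
let dfdx = gradient(at: 3 as Float, in: { x in x * x })
print(dfdx) // 6.0
```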
Force-pushed 1433de0 to 2b43542
@swift-ci Please test tensorflow
- Rename `PrimalGenCloner` to `VJPEmitter`. Primal data structures can be generated based on activity analysis.
- Refactor the differentiation transform to use an iterative `autodiff_function` instruction worklist instead of a fixed `DifferentiationTask` worklist.
  - `VJPEmitter::visitApply` now creates `autodiff_function` and `autodiff_function_extract` instructions instead of directly emitting associated function references.
  - An iterative loop canonicalizes `autodiff_function` instructions, promoting them to `@differentiable` function-typed values (see the sketch after this list).
- Add an `autodiff_function_extract` folding optimization.
  - Fold `autodiff_function_extract` users of `autodiff_function` instructions, directly replacing them with operands of the `autodiff_function` instruction.
  - If the `autodiff_function` instruction has only `autodiff_function_extract` users, delete the instruction itself after folding.
- `DifferentiationTask`: replace with map from `[differentiable]` attributes to `DifferentiationInvoker`.
- `PrimalGen`, `AdjointGen`: replace with the `autodiff_function` iterative loop.
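As a usage-level illustration of the "promoting them to `@differentiable` function-typed values" bullet above, a small sketch written for illustration in the era's surface syntax (recent toolchains spell these `import _Differentiation`, `@differentiable(reverse)`, and `gradient(at:of:)`):

```swift
func f(_ x: Float) -> Float {
  return x * x + 1
}

// Converting the function reference `f` to a `@differentiable` function-typed
// value is what lowers to an `autodiff_function` instruction; the transform's
// main loop canonicalizes it by filling in the missing associated functions.
let df: @differentiable (Float) -> Float = f

print(gradient(at: 3, in: df)) // 6.0
```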