[AutoDiff] Canonicalize SIL type for JVP/VJP methods. #24775
Conversation
Canonicalize JVP/VJP type calculation for methods so that JVP/VJP methods return a linear map that takes/returns self's tangent/cotangent last, instead of first. This matches the type calculation logic for top-level functions.

Changes:
- Remove method-specific logic from `SILFunctionType::getAutoDiffAssociatedFunctionType`. Handle `self` like a normal differentiation parameter.
- Thunk user-defined JVP/VJP methods, reordering the position of self's tangent/cotangent in returned linear maps.
  - Relevant code: `SILGenModule::getOrCreateAutoDiffAssociatedFunctionReorderingThunk` and `SILGenFunction::getOrCreateAutoDiffLinearMapReorderingThunk`.
- Change type computation for JVP/VJP method protocol witnesses.
  - Relevant code: `getSILFunctionType` in SILFunctionType.cpp and `SILGenFunction::emitProtocolWitness`.

Move some functions to a common location:
- `AnyFunctionType::getAutoDiffOriginalFunctionType`
- `autodiff::getAutoDiffFunctionLinkage`
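For a concrete illustration of the convention change, here is a minimal sketch. The struct `Point`, its method `scaled(by:)`, the `PointTangent` stand-in for self's tangent/cotangent space, and the `value:`/`pullback:` labels are all hypothetical, not code from this patch:

```swift
// Hypothetical method whose derivative we care about.
struct Point {
    var x: Float
    func scaled(by factor: Float) -> Float { return x * factor }
}

// Stand-in for Point's tangent/cotangent space.
struct PointTangent { var x: Float }

// Old convention: the VJP's pullback produced self's cotangent *first*.
typealias OldScaledVJP =
    (Float) -> (value: Float, pullback: (Float) -> (PointTangent, Float))

// New convention (this PR): self's cotangent comes *last*, matching the
// top-level-function rule where `self` is treated as just another (final)
// differentiation parameter.
typealias NewScaledVJP =
    (Float) -> (value: Float, pullback: (Float) -> (Float, PointTangent))
```

The sketch only shows the ordering of the linear map's results; it elides the SIL conventions and protocol-witness details that the patch actually adjusts.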
lib/SILGen/SILGenPoly.cpp
@@ -84,7 +84,11 @@
#include "SILGen.h"
#include "SILGenFunction.h"
// SWIFT_ENABLE_TENSORFLOW
#include "SILGenFunctionBuilder.h"
It's possible to avoid introducing these dependencies by moving the "thunk function creation" logic to SILGenThunk.cpp. This is done for reabstraction thunks:
- Thunk `SILFunction` creation is done in SILGenThunk.cpp in `SILGenModule::getOrCreateReabstractionThunk`. SILGenThunk.cpp includes the appropriate headers.
- Thunk body creation is done here in SILGenPoly.cpp.

Consider refactoring at some point.
- Fix test/Serialization/differentiable_attr.swift.
- Use `CHECK` and `CHECK-NEXT` instead of `CHECK-DAG`.
- Remove dead code from `SILGenFunction::emitProtocolWitness`.
- Gardening.
@swift-ci Please test tensorflow
I have not read through all the type calculation and thunking in detail yet, because I have a high-level comment that might simplify some of it.
Can you make the AST type of the associated functions have self last? I.e., make `AnyFunctionType::getAutoDiffAssociatedFunctionType` return types with self last?
I haven't thought this through carefully, but I think that if you did that:
- This order wouldn't be exposed to the user anywhere, because the user never gets direct access to that AST type. (Except: you might have to make a small ordering adjustment in TypeCheckAttr.)
- This would eliminate the reordering logic in SILFunctionType.cpp, because the order would already be correct.
- This would eliminate the reordering logic in `SILGenFunction::emitProtocolWitness`, because the AST types that get there would already be in the right order.
- Also, this would be overall safer and less likely to break in the future, because every internal layer would have the same order. Only the very top layer that's exposed to the user would have a different order.
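For concreteness, a rough sketch of what "self last" would mean at the AST level for a curried method type. The names `S` and `STangent` are hypothetical, and this is not the compiler's actual type printing:

```swift
// Hypothetical struct and a stand-in for its tangent/cotangent space.
struct S { var x: Float }
struct STangent { var x: Float }

// Curried AST type of an original method (Float) -> Float on S.
typealias OriginalMethod = (S) -> (Float) -> Float

// If AnyFunctionType::getAutoDiffAssociatedFunctionType already placed self's
// cotangent last, the VJP's AST type would carry the final ordering directly:
typealias MethodVJP =
    (S) -> (Float) -> (value: Float, pullback: (Float) -> (Float, STangent))
// and SILFunctionType.cpp / SILGenFunction::emitProtocolWitness would have
// nothing left to reorder.
```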
- Add doc comments to `getAutoDiffFunctionLinkage`.
- Optimize "curry level unwrapping" in `AnyFunctionType::getAutoDiffAssociatedFunctionType`.
Force-pushed 063f4fd to 7426da2.
Good idea! We discussed this in person, and I discussed it with @rxwei; we agree that computing self-parameter-reordered JVP/VJP types on AST function types instead of SIL function types is safer and simpler. Done in 063f4fd, removing 100+ lines from SILGen and adding <10 lines to AST/Sema.
@swift-ci Please test tensorflow
Address review feedback from @marcrasi. Computing self-parameter-reordered JVP/VJP types on AST function types instead of SIL function types is safer and simpler.
Force-pushed 7426da2 to 34dcfe3.
@swift-ci Please test tensorflow
Canonicalize JVP/VJP type calculation for methods so that JVP/VJP methods return a linear map that takes/returns self's tangent/cotangent last, instead of first. This matches the type calculation logic for top-level functions.

Changes:
- Remove method-specific logic from `SILFunctionType::getAutoDiffAssociatedFunctionType`. Handle `self` like a normal differentiation parameter.
- Thunk user-defined JVP/VJP methods, reordering the position of self's tangent/cotangent in returned linear maps.
  - Relevant code: `SILGenModule::getOrCreateAutoDiffAssociatedFunctionReorderingThunk` and `SILGenFunction::getOrCreateAutoDiffLinearMapReorderingThunk`.
- Change type computation for JVP/VJP method protocol witnesses.
  - Relevant code: `getSILFunctionType` in SILFunctionType.cpp and `SILGenFunction::emitProtocolWitness`.

Move some functions to a common location:
- `AnyFunctionType::getAutoDiffOriginalFunctionType`
- `autodiff::getAutoDiffFunctionLinkage`

This is a step towards refactoring the differentiation transform (iterative `autodiff_function` worklist redesign), which will unblock control-flow support.