-
Notifications
You must be signed in to change notification settings - Fork 10.5k
[AutoDiff] Revamp PullbackInfo
#26587
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[AutoDiff] Revamp PullbackInfo
#26587
Conversation
4694266
to
725308a
Compare
@swift-ci please test tensorflow |
608844b
to
b9637c9
Compare
b9637c9
to
1a7e8b6
Compare
return false; | ||
} | ||
|
||
void LinearMapInfo::populateLinearMapStructDeclarationFields( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm fine with checking this function to unblock things, but it should be broken down into multiple functions to reduce nesting.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No worries I can make it part of this PR. Currently rebuilding my Xcode build, so will push a separate commit to make this change soon once it finishes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added this change in my latest commit, separated out the code of when I find an apply instruction.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
0fa5f61
to
332c1f5
Compare
332c1f5
to
81b8592
Compare
@swift-ci please test tensorflow |
PullbackInfo
->LinearMapInfo
With the introduction of Forward Mode Differentiation, the name
PullbackInfo
is not correct anymore. To recall, it stores information about the JVP/VJP generation. thus, a more general name ofLinearMapInfo
more appropriately represents what it stores.Specifically, it stores the differential and pullback functions, which are linear functions. We map those functions to the original apply calls, which is where
LinearMapInfo
andlinearMapStructs
comes from.Revamp
linearMapStruct
creation to occur before JVP/VJP emissionIn my forward mode PR #26057 , I experienced an LLVM IR bug. What was going on is that it was caching the linear map structs before the linear functions were added. Those functions were being added while we were visiting the relevant instructions in JVP/VJP emission.
The solution to avoid this caching problem was to bring the part where we add the linear map functions earlier to when we create the
LinearMapInfo
. Thus, when the JVP/VJP emission pass occurs, it will cache the full linear map struct.