[AutoDiff upstream] Add reverse-mode automatic differentiation. #30821
Differentiable activity analysis is a dataflow analysis which marks values in a function as varied, useful, or active (both varied and useful). Only active values need a derivative.
Add `AdjointValue`: a symbolic representation for adjoint values enabling efficient differentiation by avoiding zero materialization.
`LinearMapInfo` contains information about linear map structs and branching trace enums, which are auxiliary data structures created by the differentiation transform. These data structures are constructed in JVP/VJP functions and consumed in differential/pullback functions.
`VJPEmitter` is a cloner that emits VJP functions. It implements reverse-mode automatic differentiation, along with `PullbackEmitter`. `VJPEmitter` clones an original function, replacing function applications with VJP function applications. In VJP functions, each basic block takes a pullback struct (containing callee pullbacks) and produces a predecessor enum: these data structures are consumed by pullback functions.
`PullbackEmitter` is a visitor that emits pullback functions. It implements reverse-mode automatic differentiation, along with `VJPEmitter`. Pullback functions take derivatives with respect to outputs and return derivatives with respect to inputs. Every active value/address in an original function has a corresponding adjoint value/buffer in the pullback function. Pullback functions consume pullback structs and predecessor enums constructed by VJP functions.
@gottesmm: would you like to help review this patch, as a SILOptimizer code owner?
…larations. IRGenDebugInfo crash due to lack of proper mangling for AutoDiff-generated declarations: linear map structs and branching trace enums.
91197c1 to 52374bf
@swift-ci Please test
Build failed
Build failed
Merging to unblock progress. Happy to address any feedback later!
`SourceFile::addVisibleDecl` is an unnecessary API. It was upstreamed in swiftlang#30821.
Add reverse-mode automatic differentiation and supporting utilities.
The Differentiable Programming Implementation Overview provides transformation details.
Derivative emitters
VJPEmitter
VJP functions are reverse-mode derivative functions.
`VJPEmitter` is a cloner that emits VJP functions. It replaces function applications with VJP function applications. In VJP functions, each basic block takes a pullback struct (containing callee pullbacks) and produces a predecessor enum (except for the exit block): these data structures are consumed by pullback functions.
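As a rough illustration of "replacing function applications with VJP function applications", here is a hand-written Swift sketch. It is not the output of `VJPEmitter`, which operates on SIL. The names `vjpSin`, `vjpMultiply`, `FPullbackStruct`, and `vjpF` are hypothetical and chosen only for this example.

```swift
import Foundation

// Hypothetical VJPs for two primitive callees. Each returns the original
// result together with a pullback closure.
func vjpSin(_ x: Double) -> (value: Double, pullback: (Double) -> Double) {
    (sin(x), { v in v * cos(x) })
}

func vjpMultiply(_ a: Double, _ b: Double)
    -> (value: Double, pullback: (Double) -> (Double, Double)) {
    (a * b, { v in (v * b, v * a) })
}

// Original function: f(x) = sin(x) * x.
func f(_ x: Double) -> Double {
    sin(x) * x
}

// Hypothetical "pullback struct" for the single basic block of `f`.
// It stores the pullbacks of the callees so the pullback of `f` can use them.
struct FPullbackStruct {
    let sinPullback: (Double) -> Double
    let multiplyPullback: (Double) -> (Double, Double)
}

// Sketch of a VJP for `f`: a copy of `f` in which each call is replaced by a
// call to the callee's VJP, and the callee pullbacks are recorded.
func vjpF(_ x: Double) -> (value: Double, pullback: (Double) -> Double) {
    let (s, sinPullback) = vjpSin(x)
    let (y, multiplyPullback) = vjpMultiply(s, x)
    let record = FPullbackStruct(
        sinPullback: sinPullback,
        multiplyPullback: multiplyPullback)
    return (y, { seed in
        // Run the callee pullbacks in reverse order of the original calls.
        let (ds, dxDirect) = record.multiplyPullback(seed)
        return record.sinPullback(ds) + dxDirect
    })
}
```

Calling `vjpF(2.0)` returns the same value as `f(2.0)` plus a pullback; applying that pullback to a seed of `1.0` gives the derivative `cos(2) * 2 + sin(2)`.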
PullbackEmitter
`PullbackEmitter` is a visitor that emits pullback functions. Pullback functions take derivatives with respect to outputs and return derivatives with respect to inputs. Every active value/address in an original function has a corresponding adjoint value/buffer in the pullback function.
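The correspondence between active values and adjoint values can be sketched by hand in Swift. This is only an illustration under simplifying assumptions (scalar `Double` values, straight-line code); `g` and `pullbackG` are hypothetical names, and the real emitter works on SIL values and buffers.

```swift
// Original function: g(x, y) = x * y + x.
// Active values: x, y, t, r.
func g(_ x: Double, _ y: Double) -> Double {
    let t = x * y
    let r = t + x
    return r
}

// Hand-written pullback of `g` at the point (x, y). It takes the derivative
// with respect to the output and returns derivatives with respect to the
// inputs. Each active value of `g` gets one adjoint variable.
func pullbackG(at x: Double, _ y: Double) -> (Double) -> (Double, Double) {
    return { seed in
        var adjX = 0.0
        var adjY = 0.0
        var adjT = 0.0
        let adjR = seed

        // Visit the original instructions in reverse order.
        // r = t + x
        adjT += adjR
        adjX += adjR
        // t = x * y
        adjX += adjT * y
        adjY += adjT * x

        return (adjX, adjY)
    }
}
```

With `x = 3`, `y = 4`, and a seed of `1`, the pullback returns `(5.0, 3.0)`, matching the partial derivatives `y + 1` and `x`.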
Utilities
Differentiable activity analysis
Activity analysis is a dataflow analysis which marks values in a function as varied, useful, or active (both varied and useful). Only active values need a derivative.
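The varied/useful/active classification can be sketched on a toy instruction list. The sketch below assumes straight-line code in definition order, so one forward pass and one backward pass suffice; it does not handle control flow. `Instruction`, `Activity`, and `analyzeActivity` are hypothetical names for this example and are not the compiler's API.

```swift
// Toy IR: each instruction defines one result from some operands.
struct Instruction {
    let result: String
    let operands: [String]
}

struct Activity {
    var varied: Set<String>
    var useful: Set<String>
    // Active values are both varied and useful.
    var active: Set<String> { varied.intersection(useful) }
}

func analyzeActivity(
    instructions: [Instruction],
    independent: Set<String>,
    dependent: Set<String>
) -> Activity {
    // Forward pass: a value is varied if it depends on an independent value.
    var varied = independent
    for inst in instructions where !varied.isDisjoint(with: inst.operands) {
        varied.insert(inst.result)
    }
    // Backward pass: a value is useful if a dependent value depends on it.
    var useful = dependent
    for inst in instructions.reversed() where useful.contains(inst.result) {
        useful.formUnion(inst.operands)
    }
    return Activity(varied: varied, useful: useful)
}

// a = x * x; b = y + y; c = a * b; d = a + a
let program = [
    Instruction(result: "a", operands: ["x", "x"]),
    Instruction(result: "b", operands: ["y", "y"]),
    Instruction(result: "c", operands: ["a", "b"]),
    Instruction(result: "d", operands: ["a", "a"]),
]
let activity = analyzeActivity(
    instructions: program, independent: ["x"], dependent: ["c"])
```

Differentiating `c` with respect to `x`, the value `b` is useful but not varied, and `d` is varied but not useful, so only `x`, `a`, and `c` are active.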
AdjointValue
`AdjointValue` is a symbolic representation for derivative values, enabling efficient differentiation by avoiding zero materialization.
LinearMapInfo
`LinearMapInfo` contains information about linear map structs and branching trace enums, which are auxiliary data structures created by the differentiation transform. These data structures are constructed in JVP/VJP functions and consumed in differential/pullback functions.
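The idea of avoiding zero materialization with a symbolic adjoint can be sketched in Swift as follows. `SymbolicAdjoint` is a hypothetical, heavily simplified stand-in for illustration (scalar values, two cases only) and does not mirror the actual `AdjointValue` class.

```swift
// A symbolic adjoint is either a known zero or a concrete value.
enum SymbolicAdjoint: Equatable {
    case zero
    case concrete(Double)

    // Accumulation: adding a symbolic zero needs no arithmetic and no
    // allocation of a zero value.
    static func + (lhs: SymbolicAdjoint, rhs: SymbolicAdjoint) -> SymbolicAdjoint {
        switch (lhs, rhs) {
        case (.zero, let other), (let other, .zero):
            return other
        case let (.concrete(a), .concrete(b)):
            return .concrete(a + b)
        }
    }

    // A concrete zero is produced only when a real value is finally needed.
    func materialized() -> Double {
        switch self {
        case .zero: return 0
        case .concrete(let value): return value
        }
    }
}
```

In this sketch, `.zero + .concrete(3)` yields `.concrete(3)` without ever creating a numeric zero; only `materialized()` turns a symbolic zero into `0`.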
Transformation example
Original function:
Generated VJP and pullback functions (high-level pseudocode):
Generated VJP and pullback functions (lower-level pseudocode):
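A hypothetical example in the same spirit, written by hand in Swift rather than taken from this PR, shows how a branching trace enum might connect a VJP to its pullback when the original function has control flow. The names `h`, `HBranchTrace`, and `vjpH` are assumptions for illustration; the real transform generates SIL, not Swift source.

```swift
// Original function with control flow.
func h(_ x: Double) -> Double {
    if x > 0 {
        return x * x
    } else {
        return 3 * x
    }
}

// Hypothetical branching trace enum: records which branch ran in the VJP,
// together with the pullback for that branch.
enum HBranchTrace {
    case thenBranch(pullback: (Double) -> Double)
    case elseBranch(pullback: (Double) -> Double)
}

// Sketch of a VJP for `h`: computes the original result and builds the trace.
func vjpH(_ x: Double) -> (value: Double, pullback: (Double) -> Double) {
    let value: Double
    let trace: HBranchTrace
    if x > 0 {
        value = x * x
        trace = .thenBranch(pullback: { v in v * 2 * x })
    } else {
        value = 3 * x
        trace = .elseBranch(pullback: { v in v * 3 })
    }
    // The pullback consumes the trace to replay the taken branch in reverse.
    return (value, { seed in
        switch trace {
        case .thenBranch(let pullback):
            return pullback(seed)
        case .elseBranch(let pullback):
            return pullback(seed)
        }
    })
}
```

For `x = 3` the then-branch is recorded and the pullback of seed `1` returns `6`; for `x = -2` the else-branch is recorded and the pullback returns `3`.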
Add simple VJP and pullback generation test.
TF-1232 (AutoDiff-generated declaration mangling) is the latest blocker for end-to-end tests.
More tests will be upstreamed later. It's hard to test things end-to-end until more upstreaming is done, so I'd like to make incremental upstreaming progress.