[AutoDiff upstream] Add reverse-mode automatic differentiation. #30821

Merged
merged 10 commits into swiftlang:master from the differentiation-transform branch
Apr 6, 2020

dan-zheng
Contributor

@dan-zheng dan-zheng commented Apr 6, 2020

Add reverse-mode automatic differentiation and supporting utilities.
The Differentiable Programming Implementation Overview provides transformation details.


Derivative emitters

VJPEmitter

VJP functions are reverse-mode derivative functions. VJPEmitter is a cloner that emits VJP functions.

It clones the original function, replacing function applications with VJP function applications. In VJP functions, each basic block takes a pullback struct (containing callee pullbacks) and produces a predecessor enum (except for the exit block). Pullback functions consume these data structures.

PullbackEmitter

PullbackEmitter is a visitor that emits pullback functions.

Pullback functions take derivatives with respect to outputs and return derivatives with respect to inputs. Every active value/address in an original function has a corresponding adjoint value/buffer in the pullback function.

Utilities

Differentiable activity analysis

Activity analysis is a dataflow analysis which marks values in a function as varied, useful, or active (both varied and useful). Only active values need a derivative.
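
To make the varied/useful/active distinction concrete, here is a minimal sketch over a toy straight-line IR. It is not the compiler's implementation: `Instruction` and `activeValues` are hypothetical names, and the sketch assumes instructions appear in def-before-use order with no control flow or memory, so one forward pass and one backward pass are enough.

```swift
// A toy instruction: `result = op(operands...)`. Hypothetical type for illustration only.
struct Instruction {
    let result: String
    let operands: [String]
}

// Returns the values that are both varied and useful.
// Assumes straight-line code listed in def-before-use order.
func activeValues(
    of instructions: [Instruction],
    inputs: Set<String>,
    outputs: Set<String>
) -> Set<String> {
    // Forward pass: a value is varied if it depends on a differentiation input.
    var varied = inputs
    for inst in instructions {
        if inst.operands.contains(where: { varied.contains($0) }) {
            varied.insert(inst.result)
        }
    }
    // Backward pass: a value is useful if it contributes to a differentiation output.
    var useful = outputs
    for inst in instructions.reversed() {
        if useful.contains(inst.result) {
            useful.formUnion(inst.operands)
        }
    }
    // Active values are both varied and useful.
    return varied.intersection(useful)
}

// Toy body: `k` depends only on the constant `c`, and `unused` never reaches the output.
let body = [
    Instruction(result: "y1", operands: ["x"]),
    Instruction(result: "y2", operands: ["x"]),
    Instruction(result: "k", operands: ["c"]),
    Instruction(result: "unused", operands: ["y1"]),
    Instruction(result: "y3", operands: ["y1", "y2"]),
]
let active = activeValues(of: body, inputs: ["x"], outputs: ["y3"])
```

In this toy body, `k` is not varied and `unused` is varied but not useful, so neither would need a derivative.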

AdjointValue

AdjointValue is a symbolic representation for derivative values, enabling efficient differentiation by avoiding zero materialization.
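
The idea of avoiding zero materialization can be sketched with a small enum. This is a simplified illustration, not the compiler's C++ `AdjointValue` class: `SymbolicAdjoint` and its two cases are assumptions chosen for the example.

```swift
// Hypothetical, simplified stand-in for a symbolic adjoint value.
enum SymbolicAdjoint {
    case zero             // Symbolic zero: no concrete value has been created yet.
    case concrete(Float)  // An actual derivative value.

    // Accumulation: adding a symbolic zero returns the other operand unchanged,
    // so no zero value is created and no addition is performed.
    static func + (lhs: SymbolicAdjoint, rhs: SymbolicAdjoint) -> SymbolicAdjoint {
        switch (lhs, rhs) {
        case (.zero, let other), (let other, .zero):
            return other
        case let (.concrete(a), .concrete(b)):
            return .concrete(a + b)
        }
    }

    // Materialize a concrete value only when one is actually required.
    var materialized: Float {
        switch self {
        case .zero: return 0
        case .concrete(let value): return value
        }
    }
}
```

With this representation, zero-initializing every adjoint is free, and a concrete zero is produced only if a zero adjoint is actually read.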

LinearMapInfo

LinearMapInfo contains information about linear map structs and branching trace enums, which are auxiliary data structures created by the differentiation transform.

These data structures are constructed in JVP/VJP functions and consumed in differential/pullback functions.
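
A rough sketch of how linear map structs and a branching trace enum might look for a function with one `if`/`else`. All type and function names here (`ThenBlockPullbacks`, `ElseBlockPullbacks`, `ExitPredecessor`, `vjpBar`, `pbBar`) are hypothetical, and the layout is simplified compared to what the transform generates.

```swift
import Foundation

// Original function (for reference):
//   func bar(_ x: Float) -> Float { x > 0 ? sin(x) : cos(x) }

// Linear map structs: one per branch block, holding that block's callee pullbacks.
struct ThenBlockPullbacks { let pbSin: (Float) -> Float }
struct ElseBlockPullbacks { let pbCos: (Float) -> Float }

// Branching trace enum: records which predecessor of the exit block executed,
// carrying that predecessor's pullback struct as payload.
enum ExitPredecessor {
    case thenBlock(ThenBlockPullbacks)
    case elseBlock(ElseBlockPullbacks)
}

// VJP: runs the original computation and constructs the trace.
func vjpBar(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
    let y: Float
    let trace: ExitPredecessor
    if x > 0 {
        y = sin(x)
        trace = .thenBlock(ThenBlockPullbacks(pbSin: { dy in dy * cos(x) }))
    } else {
        y = cos(x)
        trace = .elseBlock(ElseBlockPullbacks(pbCos: { dy in -dy * sin(x) }))
    }
    return (y, { dy in pbBar(dy, trace) })
}

// Pullback: consumes the trace to replay the control flow in reverse.
func pbBar(_ dy: Float, _ trace: ExitPredecessor) -> Float {
    switch trace {
    case .thenBlock(let pullbacks): return pullbacks.pbSin(dy)
    case .elseBlock(let pullbacks): return pullbacks.pbCos(dy)
    }
}
```

The pullback never re-evaluates the condition `x > 0`. It switches on the enum recorded by the VJP, which is why the trace has to be built in the VJP and consumed in the pullback.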

Transformation example

Original function:

@differentiable
func foo(_ x: Float) -> Float {
    return sin(x) * cos(x)
}

// Simplified SIL pseudocode.
sil [differentiable source 0 wrt 0] @foo : $(Float) -> Float {
bb0(%x):
  %y1 = apply @sin(%x)
  %y2 = apply @cos(%x)
  %y3 = apply @mul(%y1, %y2)
  return %y3
}

Generated VJP and pullback functions (high-level pseudocode):

// High-level pseudocode, using closure syntax.
// VJP: replaces all function applications with VJP applications.
sil @vjp_foo : $(Float) -> (Float, (Float) -> Float) {
bb0(%x):
  (%y1, %pb_sin) = apply @vjp_sin(%x)
  (%y2, %pb_cos) = apply @vjp_cos(%x)
  (%y3, %pb_mul) = apply @vjp_mul(%y1, %y2)
  // Return tuple of original result and pullback.
  return (%y3, { %dy3 in
    // All "adjoint values" in the pullback are zero-initialized.
    // %dx = 0, %dy1 = 0, %dy2 = 0
    (%dy1, %dy2) += %pb_mul(%dy3)
    (%dx) += %pb_cos(%dy2)
    (%dx) += %pb_sin(%dy1)
    return %dx
  })
}
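
For comparison, the high-level pseudocode above can be written by hand in plain Swift. This is an illustrative sketch, not compiler output: `vjpSin`, `vjpCos`, `vjpMul`, and `vjpFoo` are hypothetical hand-written functions standing in for the derivatives the transform would look up or generate.

```swift
import Foundation

// Hand-written VJPs for the callees: each returns the original result and a pullback.
func vjpSin(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
    return (sin(x), { dy in dy * cos(x) })
}
func vjpCos(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
    return (cos(x), { dy in -dy * sin(x) })
}
func vjpMul(_ a: Float, _ b: Float) -> (value: Float, pullback: (Float) -> (Float, Float)) {
    return (a * b, { dy in (dy * b, dy * a) })
}

// VJP of `foo`: every call is replaced with a call to the callee's VJP.
func vjpFoo(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
    let (y1, pbSin) = vjpSin(x)
    let (y2, pbCos) = vjpCos(x)
    let (y3, pbMul) = vjpMul(y1, y2)
    return (y3, { dy3 in
        // Adjoint of the input starts at zero and accumulates contributions.
        var dx: Float = 0
        // Callee pullbacks are applied in reverse order of the original calls.
        let (dy1, dy2) = pbMul(dy3)
        dx += pbCos(dy2)
        dx += pbSin(dy1)
        return dx
    })
}
```

Calling `vjpFoo(0.5).pullback(1)` should agree with the analytic derivative of `sin(x) * cos(x)`, which is `cos(2x)`, up to floating-point error.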

Generated VJP and pullback functions (lower-level pseudocode):

// More accurate, closure-free pseudocode.
// Pullback functions are top-level SIL functions.

// Struct containing pullback functions.
// Partially-applied to `@pb_foo` in `@vjp_foo`.
struct foo_bb0_PB_src_0_wrt_0 {
  var pb_sin: (Float) -> Float
  var pb_cos: (Float) -> Float
  var pb_mul: (Float) -> (Float, Float)
}

// VJP: replaces all function applications with VJP applications.
sil @vjp_foo : $(Float) -> (Float, (Float) -> Float) {
bb0(%x):
  (%y1, %pb_sin) = apply @vjp_sin(%x)
  (%y2, %pb_cos) = apply @vjp_cos(%x)
  (%y3, %pb_mul) = apply @vjp_mul(%y1, %y2)
  // Partially-apply to get a pullback.
  %pb_struct = struct $foo_bb0_PB_src_0_wrt_0 (%pb_sin, %pb_cos, %pb_mul)
  %pb = partial_apply @pb_foo(%pb_struct)
  // Return tuple of original result and pullback.
  %result = tuple (%y3, %pb)
  return %result
}

// Pullback: apply pullbacks to adjoint values.
sil @pb_foo : $(Float, foo_bb0_PB_src_0_wrt_0) -> (Float) {
bb0(%dy3, %pb_struct):
  // All "adjoint values" in the pullback are zero-initialized.
  // %dx = 0, %dy1 = 0, %dy2 = 0
  %pb_mul = struct_extract %pb_struct, #pb_mul
  (%dy1, %dy2) += %pb_mul(%dy3)
  %pb_cos = struct_extract %pb_struct, #pb_cos
  (%dx) += %pb_cos(%dy2)
  %pb_sin = struct_extract %pb_struct, #pb_sin
  (%dx) += %pb_sin(%dy1)
  return %dx
}


Add simple VJP and pullback generation test.
TF-1232 (AutoDiff-generated declaration mangling) is the latest blocker for end-to-end tests.

More tests will be upstreamed later. It's hard to test things end-to-end until more upstreaming is done, so I'd like to make incremental upstreaming progress.

Differentiable activity analysis is a dataflow analysis which marks values in
a function as varied, useful, or active (both varied and useful).

Only active values need a derivative.

Add `AdjointValue`: a symbolic representation for adjoint values enabling
efficient differentiation by avoiding zero materialization.

`LinearMapInfo` contains information about linear map structs and branching
trace enums, which are auxiliary data structures created by the differentiation
transform.

These data structures are constructed in JVP/VJP functions and consumed in
differential/pullback functions.

`VJPEmitter` is a cloner that emits VJP functions. It implements reverse-mode
automatic differentiation, along with `PullbackEmitter`.

`VJPEmitter` clones an original function, replacing function applications with
VJP function applications. In VJP functions, each basic block takes a pullback
struct (containing callee pullbacks) and produces a predecessor enum: these data
structures are consumed by pullback functions.

`PullbackEmitter` is a visitor that emits pullback functions. It implements
reverse-mode automatic differentiation, along with `VJPEmitter`.

Pullback functions take derivatives with respect to outputs and return
derivatives with respect to inputs. Every active value/address in an original
function has a corresponding adjoint value/buffer in the pullback function.

Pullback functions consume pullback structs and predecessor enums constructed
by VJP functions.

@dan-zheng dan-zheng requested review from rxwei and marcrasi April 6, 2020 04:31
@dan-zheng
Contributor Author

@gottesmm: would you like to help review this patch, as a SILOptimizer code owner?


include/swift/SILOptimizer/Utils/Differentiation/{Common,Thunk}.h contain some utilities that:

…larations.

IRGenDebugInfo crash due to lack of proper mangling for AutoDiff-generated
declarations: linear map structs and branching trace enums.
@dan-zheng dan-zheng force-pushed the differentiation-transform branch from 91197c1 to 52374bf Compare April 6, 2020 04:39
@dan-zheng
Contributor Author

@swift-ci Please test

@swift-ci
Contributor

swift-ci commented Apr 6, 2020

Build failed
Swift Test Linux Platform
Git Sha - 91197c1515f3bac11f94427b5c800d019ce83b13

@swift-ci
Contributor

swift-ci commented Apr 6, 2020

Build failed
Swift Test OS X Platform
Git Sha - 91197c1515f3bac11f94427b5c800d019ce83b13

@dan-zheng
Contributor Author

Merging to unblock progress. Happy to address any feedback later!

@dan-zheng dan-zheng merged commit bb0aa1c into swiftlang:master Apr 6, 2020
@dan-zheng dan-zheng deleted the differentiation-transform branch April 6, 2020 07:16
dan-zheng added a commit to dan-zheng/swift that referenced this pull request Apr 7, 2020
`SourceFile::addVisibleDecl` is an unnecessary API.
It was upstreamed in swiftlang#30821.