
[AutoDiff] Revamp differentiation transform. #24845

Merged

Conversation

@dan-zheng dan-zheng commented May 16, 2019

  • Directly generate primal code in VJP functions.
    • Rename PrimalGenCloner to VJPEmitter.
    • This unblocks control-flow support.
      Primal data structures can be generated based on activity analysis.
  • Refactor differentiation transform to use iterative autodiff_function
    instruction worklist instead of a fixed DifferentiationTask worklist.
    • VJPEmitter::visitApply now creates autodiff_function and
      autodiff_function_extract instructions instead of directly emitting
      associated function references.
    • An iterative loop canonicalizes autodiff_function instructions,
      promoting them to @differentiable function-typed values.
  • Add autodiff_function_extract folding optimization.
    • Fold autodiff_function_extract users of autodiff_function instructions,
      directly replacing them with operands of the autodiff_function
      instruction.
    • If the autodiff_function instruction has only autodiff_function_extract
      users, delete the instruction itself after folding.
  • Remove unnecessary auxiliary data structures.
    • DifferentiationTask: replace with map from [differentiable] attributes
      to DifferentiationInvoker.
    • PrimalGen, AdjointGen: replace with autodiff_function iterative loop.
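
For concreteness, here is a minimal user-level sketch of the kind of code this transform serves. It assumes a Swift for TensorFlow toolchain where `@differentiable` and `gradient(at:in:)` are available; the `square` example itself is illustrative and not taken from this PR.

```swift
// Illustrative example (not from this PR): a differentiable function and a
// gradient query of the kind the differentiation transform lowers.
@differentiable
func square(_ x: Float) -> Float {
    return x * x
}

// Passing `square` where a @differentiable function value is expected
// introduces an autodiff_function instruction at the SIL level; the revamped
// transform canonicalizes such instructions in an iterative worklist loop and
// folds away trivial autodiff_function_extract users.
let derivativeAtThree = gradient(at: 3, in: square)
print(derivativeAtThree) // 6.0
```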

@dan-zheng dan-zheng added the tensorflow label (This is for "tensorflow" branch PRs.) May 16, 2019
@dan-zheng dan-zheng requested review from rxwei and marcrasi May 16, 2019 21:08
@dan-zheng dan-zheng force-pushed the differentiation-transform-revamp branch from e9d033a to 415a259 May 16, 2019 21:08
@dan-zheng

@swift-ci Please test tensorflow

@@ -525,6 +525,15 @@ bool ReabstractionInfo::prepareAndCheck(ApplySite Apply, SILFunction *Callee,
return false;
}

// SWIFT_ENABLE_TENSORFLOW
// Disable specialization for instructions that are operands of
Contributor Author

This change is necessary to prevent post-differentiation SIL passes from generating invalid SIL.


In particular, this change disables specialization of operands to autodiff_function.
It may be possible to teach the specializer how to handle autodiff_function instructions, but I don't know how much work that would take; it seems difficult.

@@ -74,7 +74,15 @@ static bool foldInverseReabstractionThunks(PartialApplyInst *PAI,
SILInstruction *SILCombiner::visitPartialApplyInst(PartialApplyInst *PAI) {
// partial_apply without any substitutions or arguments is just a
// thin_to_thick_function.
if (!PAI->hasSubstitutions() && (PAI->getNumArguments() == 0)) {
if (!PAI->hasSubstitutions() && (PAI->getNumArguments() == 0) &&
Contributor Author

This change is necessary to prevent post-differentiation SIL passes from generating invalid SIL.


This change is rather innocuous and can be upstreamed to master. It just adds a previously unexercised check.

// redesign: currently, each original function + attribute pair is mapped
// only to one invoker.
/*
DifferentiationInvoker indirect(ai, this->original, attr);
Contributor Author

Note: indirect differentiation invokers are currently not handled post-refactoring. I've disabled the test for now.

- Directly generate primal code in VJP functions.
  - Rename `PrimalGenCloner` to `VJPEmitter`.
  - This unblocks control-flow support.
    Primal data structures can be determined based on activity analysis.
- Refactor differentiation transform to use iterative `autodiff_function`
  instruction worklist instead of a fixed `DifferentiationTask` worklist.
  - `VJPEmitter::visitApply` now creates `autodiff_function` and
    `autodiff_function_extract` instructions instead of directly emitting
    associated function references.
  - An iterative loop canonicalizes `autodiff_function` instructions, filling
    in missing associated functions.
- Remove unnecessary auxiliary data structures.
  - `DifferentiationTask`: replace with map from original function and
    `[differentiable]` attribute to `DifferentiationInvoker`.
  - `PrimalGen`, `AdjointGen`: replace with `autodiff_function` iterative loop.
@dan-zheng dan-zheng force-pushed the differentiation-transform-revamp branch from 415a259 to 1efbac6 May 16, 2019 22:19
@dan-zheng dan-zheng force-pushed the differentiation-transform-revamp branch from 1efbac6 to 9382eaf May 16, 2019 22:32
test/AutoDiff/autodiff_indirect_diagnostics.swift
@@ -6,14 +6,13 @@

Contributor

Could you please add a note in this file to say what's unexpected/broken?

Contributor Author

I'm actually not sure what the expected behavior for these tests is; I'll debug a bit to see what's going on.

At a glance, the changes seem less than ideal: the old diagnostics were more informative and appeared at the user's gradient invocation.

--- a/test/AutoDiff/autodiff_indirect_diagnostics.swift
+++ b/test/AutoDiff/autodiff_indirect_diagnostics.swift
@@ -6,14 +6,13 @@

 // Test unmet generic requirements.

-// expected-error @+1 {{function is not differentiable}}
 @differentiable
-// expected-note @+1 {{when differentiating this function definition}}
 func generic<T: Differentiable & FloatingPoint>(_ x: T) -> T {
+  // expected-error @+2 {{expression is not differentiable}}
   // expected-note @+1 {{member is not differentiable because the corresponding protocol requirement is not '@differentiable'}}
   return x + 1
 }
-_ = gradient(at: 1.0, in: generic) // expected-error {{function is not differentiable}}
+_ = gradient(at: 1.0, in: generic)

Contributor Author

I haven't quite made sense of the diagnostic behavior changes yet.
We plan to change/fix diagnostic behavior soon, so it would be good to revisit these changes as part of that follow-up.

@dan-zheng

Ready for re-review. cc @rxwei
I'm investigating why diagnostics changed in test/AutoDiff/autodiff_indirect_diagnostics.swift.

@rxwei rxwei left a comment

The patch LGTM overall. Huge congrats on getting this done!

There are two improvements I think are necessary before merging:

  1. The AdditiveArithmetic.zero emission logic in a thunk builder should be unified with what's already in AdjointEmitter.
  2. I suspect there's a performance regression caused by the extra autodiff_function and autodiff_function_extract instructions being generated, which could pose a barrier to later optimizations. Folding trivial autodiff_function/autodiff_function_extract pairs can be done right next to the logic that canonicalizes autodiff_function instructions in the main loop.

Other than those, we should talk about fixing diagnostics for indirect calls soon.

- Create common function `emitZeroIntoBuffer`.
  - Shared by `AdjointEmitter::emitZeroIndirect` and `buildZeroArgument` lambda.
- Minor naming/style changes.
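
As a rough Swift-level analogy (not the C++ helper added in this PR), the value that this zero-emission logic materializes is the additive identity that AdditiveArithmetic exposes as `.zero`; the generic function below is hypothetical and only illustrates that semantics.

```swift
// Hypothetical illustration: the zero written by the zero-emission helper
// corresponds to the additive identity of an AdditiveArithmetic type.
func zeroTangent<T: AdditiveArithmetic>(of type: T.Type) -> T {
    return T.zero
}

let zeroFloat = zeroTangent(of: Float.self)   // 0.0
let zeroDouble = zeroTangent(of: Double.self) // 0.0
print(zeroFloat, zeroDouble)
```
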
@dan-zheng

dan-zheng commented May 17, 2019

Big thanks to @rxwei for the thorough review and advice!

  1. Unification of zero emission logic is done in 646096e.
  2. autodiff_function_extract folding optimization is done in 064e61e.
    • As we discussed, a -Xllvm -differentiation-skip-folding-autodiff-function-extraction flag was added to disable folding for SIL testing purposes; it is used in three tests. A hypothetical usage sketch follows below.
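
For illustration, here is a hypothetical lit RUN line showing how a SIL FileCheck test might pass that flag; only the flag name is taken from this PR, the rest of the invocation is an assumption.

```swift
// Hypothetical RUN line; only the -Xllvm flag name comes from this PR.
// RUN: %target-swift-frontend -emit-sil %s \
// RUN:   -Xllvm -differentiation-skip-folding-autodiff-function-extraction \
// RUN:   | %FileCheck %s
```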

@dan-zheng dan-zheng force-pushed the differentiation-transform-revamp branch from 064e61e to 1433de0 May 17, 2019 09:11
Fold `autodiff_function_extract` users of `autodiff_function` instructions,
directly replacing them with operands of the `autodiff_function` instruction.

If the `autodiff_function` instruction has no non-`autodiff_function_extract`
users, delete the instruction itself after folding.

The `differentiation-skip-folding-autodiff-function-extraction` flag disables
folding for SIL testing purposes.
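
Below is a standalone toy model of this folding rule, sketched in Swift. The types are hypothetical stand-ins rather than the compiler's actual SIL classes; the sketch only shows the replacement-and-deletion logic described above.

```swift
// Toy stand-ins for SIL entities (hypothetical, for illustration only).
struct FunctionValue { let name: String }

// Components that an autodiff_function_extract can pull out of a bundle.
enum Component: Hashable { case original, jvp, vjp }

// Toy autodiff_function: bundles an original function with its associated
// derivative functions.
struct AutoDiffFunction {
    let operands: [Component: FunctionValue]
}

// Toy autodiff_function_extract: names one component of a bundle.
struct Extract {
    let component: Component
}

// Fold each extract directly to the corresponding operand of the bundle.
// If the bundle has no users other than extracts, it is dead after folding
// and can be deleted.
func fold(_ extracts: [Extract], of bundle: AutoDiffFunction,
          hasOtherUsers: Bool) -> (replacements: [FunctionValue], deleteBundle: Bool) {
    let replacements = extracts.compactMap { bundle.operands[$0.component] }
    return (replacements, !hasOtherUsers)
}

// Example: fold an extract of the VJP component out of a bundle for `square`.
let bundle = AutoDiffFunction(operands: [
    .original: FunctionValue(name: "square"),
    .vjp: FunctionValue(name: "square_vjp"),
])
let (replacements, deleteBundle) = fold([Extract(component: .vjp)],
                                        of: bundle, hasOtherUsers: false)
print(replacements.map { $0.name }, deleteBundle) // ["square_vjp"] true
```
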
@dan-zheng dan-zheng force-pushed the differentiation-transform-revamp branch from 1433de0 to 2b43542 May 17, 2019 09:13
@dan-zheng

@swift-ci Please test tensorflow

@dan-zheng dan-zheng merged commit 0a9181a into swiftlang:tensorflow May 17, 2019
@dan-zheng dan-zheng deleted the differentiation-transform-revamp branch May 17, 2019 09:50