[AutoDiff upstream] Add differentiability witness SILGen. #30545

Merged
merged 1 commit into from
Mar 21, 2020

dan-zheng
Contributor

In the AST, @differentiable and @derivative attributes represent function differentiability.
In SIL, differentiability witnesses represent function differentiability.

Generate SIL differentiability witnesses from @differentiable and @derivative attributes.


Add SILGen utilities for:

  • Emitting differentiability witnesses.
  • Creating derivative function thunks, which are used as entries in differentiability witnesses.

When users register a custom derivative function, it is necessary to create a thunk with the expected derivative type computed from the original function's type. This is important for consistent typing and consistent differentiability witness entry mangling.


Example:

func id<T>(_ x: T) -> T { x }

@derivative(of: id)
func derivative<T: Differentiable>(_ x: __owned T) -> (
  value: T, pullback: (T.TangentVector) -> T.TangentVector
) { fatalError() }

// SIL details:
//
//   Original function type: $@convention(thin) <T> (@in_guaranteed T) -> @out T
//   Actual derivative type: $@convention(thin) <T where T : Differentiable> (@in T) -> (@out T, @owned @callee_guaranteed (@in_guaranteed T.TangentVector) -> @out T.TangentVector)
// Expected derivative type: $@convention(thin) <T where T : Differentiable> (@in_guaranteed T) -> (@out T, @owned @callee_guaranteed (@in_guaranteed T.TangentVector) -> @out T.TangentVector)
//
// In this case, the derivative thunk performs reabstraction.
// For original functions that are methods, derivative thunks also perform parameter reordering.
// See `SILGenModule::getOrCreateCustomDerivativeThunk` documentation for details.

SILGen emits:

// differentiability witness for id<A>(_:)
sil_differentiability_witness hidden [parameters 0] [results 0] <T where T : Differentiable> @$s3foo2idyxxlF : $@convention(thin) <T> (@in_guaranteed T) -> @out T {
  vjp: @AD__$s3foo2idyxxlF__vjp_src_0_wrt_0_s14DifferentiableRzl : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> @out τ_0_0.TangentVector)
}

// derivative thunk for id<A>(_:) [parameters 0] [results 0]
sil hidden [thunk] [always_inline] [ossa] @AD__$s3foo2idyxxlF__vjp_src_0_wrt_0_s14DifferentiableRzl : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> @out τ_0_0.TangentVector) {
// %0                                             // user: %5
// %1                                             // user: %4
bb0(%0 : $*τ_0_0, %1 : $*τ_0_0):
  // function_ref derivative<A>(_:)
  %2 = function_ref @$s3foo10derivativeyx5value_13TangentVectorQzAEc8pullbacktxns14DifferentiableRzlF : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in τ_0_0) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> @out τ_0_0.TangentVector) // user: %5
  %3 = alloc_stack $τ_0_0                        // users: %6, %5, %4
  copy_addr %1 to [initialization] %3 : $*τ_0_0  // id: %4
  %5 = apply %2<τ_0_0>(%0, %3) : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in τ_0_0) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> @out τ_0_0.TangentVector) // user: %7
  dealloc_stack %3 : $*τ_0_0                     // id: %6
  return %5 : $@callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> @out τ_0_0.TangentVector // id: %7
} // end sil function 'AD__$s3foo2idyxxlF__vjp_src_0_wrt_0_s14DifferentiableRzl'
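
For readers less familiar with SIL, here is a rough Swift-level analogue of what the derivative thunk above does. It is only a sketch: the real thunk is emitted in SIL by SILGen, and the names `customDerivative` and `derivativeThunk` are hypothetical. It uses `Double` instead of a generic `Differentiable` type so that it compiles without `_Differentiation`. The thunk accepts its argument at the default (guaranteed) convention, copies it into a temporary, and forwards the copy to the derivative that expects an owned argument.

```swift
// The original function being differentiated.
func id<T>(_ x: T) -> T { x }

// Stand-in for the user's custom derivative. Like the PR example, it takes
// its argument `__owned`. `Double` replaces the generic `T` (assumption made
// so the sketch does not need the `_Differentiation` module).
func customDerivative(_ x: __owned Double) -> (value: Double, pullback: (Double) -> Double) {
  (value: id(x), pullback: { v in v })
}

// Hypothetical hand-written analogue of the SILGen-emitted derivative thunk.
// It exposes the expected signature (argument not consumed) and forwards to
// the custom derivative.
func derivativeThunk(_ x: Double) -> (value: Double, pullback: (Double) -> Double) {
  let temporary = x                    // mirrors `alloc_stack` + `copy_addr`
  return customDerivative(temporary)   // mirrors `apply`
}

let result = derivativeThunk(3.0)
print(result.value, result.pullback(1.0))
```

The differentiability witness refers to the thunk rather than to the custom derivative directly, so every witness entry has the type computed from the original function's type.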

Resolves TF-1138.
Incremental improvements are tracked via TODO comments.

Generate SIL differentiability witnesses from `@differentiable` and
`@derivative` declaration attributes.

Add SILGen utilities for:
- Emitting differentiability witnesses.
- Creating derivative function thunks, which are used as entries in
  differentiability witnesses.

When users register a custom derivative function, it is necessary to create a
thunk with the expected derivative type computed from the original function's
type. This is important for consistent typing and consistent differentiability
witness entry mangling.

See `SILGenModule::getOrCreateCustomDerivativeThunk` documentation for details.

Resolves TF-1138.
@dan-zheng dan-zheng requested review from marcrasi and rxwei March 20, 2020 23:12
@dan-zheng
Contributor Author

@jckarter @rjmccall: would you like to review this patch, as SIL/SILGen code owners?

I think the most controversial part of this patch is adding custom SILGen thunking logic.

We tried to reuse existing SILGen infrastructure for reabstraction thunking, but couldn’t make it work because derivative function thunks currently also perform parameter reordering to reduce complexity for the differentiation transform. (documentation)

We could undo this parameter reordering (#24775) and investigate reusing SILGen reabstraction thunking infrastructure, but I’d like to do so as an incremental improvement if that's okay!
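
To make the reordering concrete, here is a small Swift-level sketch. It is illustrative only: `Scale`, `userPullback`, and `reorderSelfLast` are hypothetical names, and the direction of the reordering (the `self` cotangent moving from first to last) is an assumption for this example, not something stated in this thread. The actual logic lives in `SILGenModule::getOrCreateCustomDerivativeThunk`.

```swift
// Hypothetical type with a method to differentiate with respect to
// both `self` and `x`.
struct Scale {
  var factor: Double
  func apply(_ x: Double) -> Double { factor * x }
}

// Hand-written pullback for `apply`: maps the cotangent `v` of the result to
// cotangents for (self.factor, x), with the `self` cotangent first.
func userPullback(_ s: Scale, _ x: Double) -> (Double) -> (Double, Double) {
  { v in (v * x, v * s.factor) }
}

// Sketch of a reordering thunk. ASSUMPTION: the `self` cotangent moves from
// the first position to the last.
func reorderSelfLast(
  _ pullback: @escaping (Double) -> (Double, Double)
) -> (Double) -> (Double, Double) {
  { v in
    let (dSelf, dX) = pullback(v)
    return (dX, dSelf)
  }
}

let pb = userPullback(Scale(factor: 3), 2)
let reordered = reorderSelfLast(pb)
print(pb(1), reordered(1))
```

Because this reordering changes the order of results, a thunk that only reabstracts individual values cannot express it, which is why the existing reabstraction infrastructure did not fit as-is.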

Comment on lines +1624 to +1627
if (innerASTTy->hasArchetype())
  innerASTTy = innerASTTy->mapTypeOutOfContext()->getCanonicalType();
if (outerASTTy->hasArchetype())
  outerASTTy = outerASTTy->mapTypeOutOfContext()->getCanonicalType();
Contributor Author

Note: this change is necessary to avoid a crash in SILGenModule::getOrCreateCustomDerivativeThunk when calling forwardFunctionArguments:

import _Differentiation

func foo<T: Differentiable>(_ x: T) -> T { x }

@derivative(of: foo)
func vjpFoo<T: Differentiable>(_ x: T) -> (value: T, pullback: (T.TangentVector) -> T.TangentVector) {
  fatalError()
}
$ swiftc -Xfrontend -enable-experimental-differentiable-programming -emit-silgen okay.swift
innerASTTy:
(primary_archetype_type address=0x7f87b4116f90 conforms_to=_Differentiation.(file).Differentiable name=τ_0_0

  (nested_type=TangentVector <<unresolved>>))

outerASTTy:
(generic_type_param_type depth=0 index=0)

unhandled reabstraction type mismatch
UNREACHABLE executed at /Users/danielzheng/swift-master/swift/lib/SILGen/SILGenPoly.cpp:1663!
Stack dump:
0.	Program arguments: /Users/danielzheng/swift-master/build/Ninja-ReleaseAssert/swift-macosx-x86_64/bin/swift -frontend -emit-silgen -primary-file okay.swift -target x86_64-apple-darwin18.7.0 -enable-objc-interop -sdk /Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX10.15.sdk -color-diagnostics -enable-experimental-differentiable-programming -module-name okay -o -
1.	Swift version 5.3-dev (LLVM 97ca3af693, Swift 59a9fb3b55)
2.	While evaluating request SILGenSourceFileRequest(SIL Generation for file "okay.swift")
3.	While emitting SIL for 'vjpFoo(_:)' (at okay.swift:6:1)
4.	While silgen emitFunction SIL function "@$s4okay6vjpFooyx5value_13TangentVectorQzAEc8pullbacktx16_Differentiation14DifferentiableRzlF".
 for 'vjpFoo(_:)' (at okay.swift:6:1)
0  swift                    0x00000001060818e5 llvm::sys::PrintStackTrace(llvm::raw_ostream&) + 37
1  swift                    0x0000000106080898 llvm::sys::RunSignalHandlers() + 248
2  swift                    0x0000000106081edc SignalHandler(int) + 268
3  libsystem_platform.dylib 0x00007fff6deebb5d _sigtramp + 29
4  swift                    0x00000001075243c2 cmark_strbuf_put.cold.1 + 317266
5  libsystem_c.dylib        0x00007fff6dda56a6 abort + 127
6  swift                    0x000000010749a9ae llvm::llvm_unreachable_internal(char const*, char const*, unsigned int) + 462
7  swift                    0x0000000102041d35 applyTrivialConversions(swift::Lowering::SILGenFunction&, swift::SILLocation, swift::Lowering::ManagedValue, swift::SILType) + 821
8  swift                    0x0000000102038836 forwardFunctionArguments(swift::Lowering::SILGenFunction&, swift::SILLocation, swift::CanTypeWrapper<swift::SILFunctionType>, llvm::ArrayRef<swift::Lowering::ManagedValue>, llvm::SmallVectorImpl<swift::SILValue>&) + 390
9  swift                    0x0000000102037a00 swift::Lowering::SILGenModule::getOrCreateCustomDerivativeThunk(swift::SILFunction*, swift::SILFunction*, swift::AutoDiffConfig const&, swift::AutoDiffDerivativeFunctionKind) + 2848
10 swift                    0x0000000101f779de swift::Lowering::SILGenModule::emitDifferentiabilityWitness(swift::AbstractFunctionDecl*, swift::SILFunction*, swift::AutoDiffConfig const&, swift::SILFunction*, swift::SILFunction*, swift::DeclAttribute const*) + 526
11 swift                    0x0000000101f7776f swift::Lowering::SILGenModule::emitDifferentiabilityWitnessesForFunction(swift::SILDeclRef, swift::SILFunction*)::$_1::operator()(swift::DeclAttributes&) const + 1103
12 swift                    0x0000000101f7730c swift::Lowering::SILGenModule::emitDifferentiabilityWitnessesForFunction(swift::SILDeclRef, swift::SILFunction*) + 172
13 swift                    0x0000000101f77245 swift::Lowering::SILGenModule::postEmitFunction(swift::SILDeclRef, swift::SILFunction*) + 245
14 swift                    0x0000000101f7e4c8 swift::Lowering::SILGenModule::emitFunction(swift::FuncDecl*)::$_3::operator()(swift::SILFunction*) const + 264
Failing Tests (2):
    Swift(macosx-x86_64) :: AutoDiff/SILGen/sil_differentiability_witness_silgen.swift
    Swift(macosx-x86_64) :: AutoDiff/Serialization/derivative_attr.swift

@jckarter: I wonder if this change is reasonable? It feels bad.

I looked a bit into changing other code (SILFunctionType::getAutoDiffDerivativeFunctionType) so that this change isn't necessary but couldn't figure it out. Maybe others have ideas.

@dan-zheng
Contributor Author

@swift-ci Please smoke test

@dan-zheng
Contributor Author

Merging to unblock progress. Happy to address feedback later!

@dan-zheng dan-zheng merged commit 24445dd into swiftlang:master Mar 21, 2020
@dan-zheng dan-zheng deleted the autodiff-upstream-sil branch March 21, 2020 09:05