-
Notifications
You must be signed in to change notification settings - Fork 10.5k
[AutoDiff upstream] Add differentiability witness SILGen. #30545
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Generate SIL differentiability witnesses from `@differentiable` and `@derivative` declaration attributes. Add SILGen utilities for: - Emiting differentiability witnesses. - Creating derivative function thunks, which are used as entries in differentiability witnesses. When users register a custom derivative function, it is necessary to create a thunk with the expected derivative type computed from the original function's type. This is important for consistent typing and consistent differentiability witness entry mangling. See `SILGenModule::getOrCreateCustomDerivativeThunk` documentation for details. Resolves TF-1138.
@jckarter @rjmccall: would you like to review this patch, as SIL/SILGen code owners? I think the most controversial part of this patch is adding custom SILGen thunking logic. We tried to reuse existing SILGen infrastructure for reabstraction thunking, but couldn’t make it work because derivative function thunks currently also perform parameter reordering to reduce complexity for the differentiation transform. (documentation) We could undo this parameter reordering (#24775) and investigate reusing SILGen reabstraction thunking infrastructure, but I’d like to do so as an incremental improvement if that's okay! |
if (innerASTTy->hasArchetype()) | ||
innerASTTy = innerASTTy->mapTypeOutOfContext()->getCanonicalType(); | ||
if (outerASTTy->hasArchetype()) | ||
outerASTTy = outerASTTy->mapTypeOutOfContext()->getCanonicalType(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note: this change is necessary to avoid a crash in SILGenModule::getOrCreateCustomDerivativeThunk
when calling forwardFunctionArguments
:
import _Differentiation
func foo<T: Differentiable>(_ x: T) -> T { x }
@derivative(of: foo)
func vjpFoo<T: Differentiable>(_ x: T) -> (value: T, pullback: (T.TangentVector) -> T.TangentVector) {
fatalError()
}
$ swiftc -Xfrontend -enable-experimental-differentiable-programming -emit-silgen okay.swift
innerASTTy:
(primary_archetype_type address=0x7f87b4116f90 conforms_to=_Differentiation.(file).Differentiable name=τ_0_0
(nested_type=TangentVector <<unresolved>>))
outerASTTy:
(generic_type_param_type depth=0 index=0)
unhandled reabstraction type mismatch
UNREACHABLE executed at /Users/danielzheng/swift-master/swift/lib/SILGen/SILGenPoly.cpp:1663!
Stack dump:
0. Program arguments: /Users/danielzheng/swift-master/build/Ninja-ReleaseAssert/swift-macosx-x86_64/bin/swift -frontend -emit-silgen -primary-file okay.swift -target x86_64-apple-darwin18.7.0 -enable-objc-interop -sdk /Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX10.15.sdk -color-diagnostics -enable-experimental-differentiable-programming -module-name okay -o -
1. Swift version 5.3-dev (LLVM 97ca3af693, Swift 59a9fb3b55)
2. While evaluating request SILGenSourceFileRequest(SIL Generation for file "okay.swift")
3. While emitting SIL for 'vjpFoo(_:)' (at okay.swift:6:1)
4. While silgen emitFunction SIL function "@$s4okay6vjpFooyx5value_13TangentVectorQzAEc8pullbacktx16_Differentiation14DifferentiableRzlF".
for 'vjpFoo(_:)' (at okay.swift:6:1)
0 swift 0x00000001060818e5 llvm::sys::PrintStackTrace(llvm::raw_ostream&) + 37
1 swift 0x0000000106080898 llvm::sys::RunSignalHandlers() + 248
2 swift 0x0000000106081edc SignalHandler(int) + 268
3 libsystem_platform.dylib 0x00007fff6deebb5d _sigtramp + 29
4 swift 0x00000001075243c2 cmark_strbuf_put.cold.1 + 317266
5 libsystem_c.dylib 0x00007fff6dda56a6 abort + 127
6 swift 0x000000010749a9ae llvm::llvm_unreachable_internal(char const*, char const*, unsigned int) + 462
7 swift 0x0000000102041d35 applyTrivialConversions(swift::Lowering::SILGenFunction&, swift::SILLocation, swift::Lowering::ManagedValue, swift::SILType) + 821
8 swift 0x0000000102038836 forwardFunctionArguments(swift::Lowering::SILGenFunction&, swift::SILLocation, swift::CanTypeWrapper<swift::SILFunctionType>, llvm::ArrayRef<swift::Lowering::ManagedValue>, llvm::SmallVectorImpl<swift::SILValue>&) + 390
9 swift 0x0000000102037a00 swift::Lowering::SILGenModule::getOrCreateCustomDerivativeThunk(swift::SILFunction*, swift::SILFunction*, swift::AutoDiffConfig const&, swift::AutoDiffDerivativeFunctionKind) + 2848
10 swift 0x0000000101f779de swift::Lowering::SILGenModule::emitDifferentiabilityWitness(swift::AbstractFunctionDecl*, swift::SILFunction*, swift::AutoDiffConfig const&, swift::SILFunction*, swift::SILFunction*, swift::DeclAttribute const*) + 526
11 swift 0x0000000101f7776f swift::Lowering::SILGenModule::emitDifferentiabilityWitnessesForFunction(swift::SILDeclRef, swift::SILFunction*)::$_1::operator()(swift::DeclAttributes&) const + 1103
12 swift 0x0000000101f7730c swift::Lowering::SILGenModule::emitDifferentiabilityWitnessesForFunction(swift::SILDeclRef, swift::SILFunction*) + 172
13 swift 0x0000000101f77245 swift::Lowering::SILGenModule::postEmitFunction(swift::SILDeclRef, swift::SILFunction*) + 245
14 swift 0x0000000101f7e4c8 swift::Lowering::SILGenModule::emitFunction(swift::FuncDecl*)::$_3::operator()(swift::SILFunction*) const + 264
Failing Tests (2):
Swift(macosx-x86_64) :: AutoDiff/SILGen/sil_differentiability_witness_silgen.swift
Swift(macosx-x86_64) :: AutoDiff/Serialization/derivative_attr.swift
@jckarter: I wonder if this change is reasonable? It feels bad.
I looked a bit into changing other code (SILFunctionType::getAutoDiffDerivativeFunctionType
) so that this change isn't necessary but couldn't figure it out. Maybe others have ideas.
@swift-ci Please smoke test |
Merging to unblock progress. Happy to address feedback later! |
In the AST,
@differentiable
and@derivative
attributes represent function differentiability.In SIL, differentiability witnesses represent function differentiability.
Generate SIL differentiability witnesses from
@differentiable
and@derivative
attributes.Add SILGen utilities for:
When users register a custom derivative function, it is necessary to create a thunk with the expected derivative type computed from the original function's type. This is important for consistent typing and consistent differentiability witness entry mangling.
Example:
SILGen emits:
Resolves TF-1138.
Incremental improvements are tracked via
TODO
comments.