[AutoDiff] WIP: Use owned callee convention for linear maps. #34935
Nice!
Note: There are two places where reabstracting back to
|
…(from:)'. Remove unused APIs `differentiableFunction(from:)` and `linearFunction(from:)`. They were never official APIs, are not included in the [initial proposal](https://github.com/rxwei/swift-evolution/blob/autodiff/proposals/0000-differentiable-programming.md#make-a-function-differentiable-using-derivative), and are unused by existing supported client libraries (SwiftFusion and S4TF). Most importantly, they block crucial optimizations on linear map closures (swiftlang#34935) and would need nontrivial work in SILGen to support.
513fc9a
to
5b3a3e3
How can I complete this PR? I'm looking at completing SR-15580 as well as gaining more knowledge about the compiler in general to prepare for debugging autodiff. |
@CodaFi I know this is a stretch, but do you have the capacity to help me out here? It seems like I'll have to wait a long time before anybody in the old S4TF team does. |
I'm not sure why this PR is a work in progress. Perhaps some tests didn't pass? Try rebasing it and we can get CI running to see what its current state is. |
@CodaFi I'm having trouble figuring out how to rebase. I tried making a new PR from this branch to main and there are conflicts. However, instead of letting me make a new PR, it only gives me the link to this PR. I'm probably doing something wrong because I have limited experience with Git. |
Thanks for the ping. I'm giving this a shot this weekend. |
Switch to `@callee_owned` callee convention for all linear map functions (differentials and pullbacks) returned from derivative functions. This halves the reference counting operations in compiler-generated derivatives, and enables child linear maps that are called in linear maps to be destroyed right after the call. Resolves rdar://71892494.
5b3a3e3
to
e8ebcc5
Just rebased on top of main and most things are working. For anyone interested in picking up this work, the reason this was still WIP is that some unexpected memory leaks are causing validation tests to fail:
Swift(macosx-x86_64) :: AutoDiff/validation-test/differentiable_protocol_requirements.swift
Swift(macosx-x86_64) :: AutoDiff/validation-test/existential.swift
Swift(macosx-x86_64) :: AutoDiff/validation-test/forward_mode_simple.swift
For example:
[ RUN ] ProtocolRequirementDifferentiation.func
stdout>>> check failed at /Volumes/Media/Development/Swift/swift-source/swift/test/AutoDiff/validation-test/differentiable_protocol_requirements.swift, line 55
stdout>>> expected: 0 (of type Swift.Int)
stdout>>> actual: 12 (of type Swift.Int)
stdout>>> Leaks detected: 12
[ FAIL ] ProtocolRequirementDifferentiation.func
[ RUN ] ProtocolRequirementDifferentiation.constructor, accessor, subscript
stdout>>> check failed at /Volumes/Media/Development/Swift/swift-source/swift/test/AutoDiff/validation-test/differentiable_protocol_requirements.swift, line 117
stdout>>> expected: 0 (of type Swift.Int)
stdout>>> actual: 1 (of type Swift.Int)
stdout>>> Leaks detected: 1
[ FAIL ] ProtocolRequirementDifferentiation.constructor, accessor, subscript
ProtocolRequirementDifferentiation: Some tests failed, aborting
UXPASS: []
FAIL: ["func", "constructor, accessor, subscript"]
SKIP: []
TODO:
|
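A leak count like the ones reported in the log above is typically obtained by comparing the number of live objects before and after a block of code. The sketch below is a simplified, hypothetical model and not the harness used by the Swift validation tests; `LiveCounter`, `Tracked`, and `countLeaks` are names made up for illustration. It shows how an unbalanced +1 retain, the kind of imbalance a mismatched callee convention could produce, shows up as a nonzero count.

```swift
// Hypothetical model of leak counting; not the Swift test suite's harness.
final class LiveCounter {
    var count = 0
}

final class Tracked {
    private let counter: LiveCounter
    init(_ counter: LiveCounter) {
        self.counter = counter
        counter.count += 1
    }
    deinit { counter.count -= 1 }
}

/// Runs `body` and returns how many `Tracked` objects it left alive.
func countLeaks(_ counter: LiveCounter, _ body: () -> Void) -> Int {
    let before = counter.count
    body()
    return counter.count - before
}

let counter = LiveCounter()

// Balanced: the temporary object is destroyed once it is no longer used.
let balancedLeaks = countLeaks(counter) {
    _ = Tracked(counter)
}

// Unbalanced: `passRetained` adds a +1 retain that is never released,
// so the object outlives the block.
let unbalancedLeaks = countLeaks(counter) {
    _ = Unmanaged.passRetained(Tracked(counter))
}

print("balanced: \(balancedLeaks), unbalanced: \(unbalancedLeaks)")
```

In this model the balanced block reports 0 leaks and the unbalanced block reports 1, which mirrors the "expected: 0" versus "actual: N" shape of the failures above.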
cc @BradLarson |
Thanks for investigating this! I don’t know if my current experience with the compiler is sufficient to fix the memory leak, though. Maybe after implementing the cross-import overlay. I’d like to take on this bug myself in the future, because it’s a great stepping stone for future debugging experience. It’s small, specific and tractable. Since changing the calling convention is just an optimization, would it be okay if it got delayed because I waited to fix it? |
@philipturner This is really not a "small, specific and tractable" task for Swift compiler beginners. It requires a good understanding of SIL and the differentiation transform, because the leak is likely due to incorrectly generated code within SILGen and the differentiation transform. Also, it's not "just an optimization", as it's a major ABI change. I don't have the bandwidth to tackle this in the near future, and was hoping that @BradLarson (and team) could take a look when they need to work on further optimizations that are blocked by this. @BradLarson should also be able to suggest some starter tasks for you. |
Okay, thanks for the advice! I think the cross-import overlay implementation might go quickly, so I might need starter tasks soon. In the meantime, I'll look for ways to get familiar with the Swift compiler on my own. |
The SIL for generated pullback looks a bit suspicious to me:
Note the following:
So it looks like we're changing calling convention here from |
These pullbacks were actually |
Yes, makes sense. Ok, in terms of leaked values in existential.swift, there are two allocs of "3" from:
but we release only one |
Weird... If there's nothing I missed from applying the pullback bitcasting workaround, I'd expect any leaks to trigger an OSSA verification failure. Something apparently slipped through. We can check if this |
Things might be a bit more interesting. The corresponding |
So, yes. On |
So, the value is captured in the pullback of
It seems like this pullback is never released. |
Could this be an IRGen issue when it's emitting a non- It's safe to assume that this library-defined pullback is emitted correctly since its closure lowering is not changed by this PR. Library-defined pullbacks are being captured by parent pullbacks generated by the differentiation transform via non- |
The partial apply is |
Right, this particular |
Yes, it's one level up:
On |
What about the generated partial application forwarders in LLVM IR? |
@rxwei I rebased this PR into However, the bitcasts are now from / to ABI-incompatible types, e.g. from Restoring the reabstraction thunks resolved these issues (the leaks remained, but that is a separate issue) :) |
That's a bit surprising though, as we should only be changing the callee convention, not the result abstraction. |
Yes. And there was nothing like this in my previous rebase... |
I'd recommend debugging this with the previous commit. My understanding from @gottesmm was that |
@rxwei Yes, this is the plan. I'm just documenting some things that might help us :) |
@rxwei Ok, the problem exists in main as well :) What happens: for captured values we have the following pullback type: On We started to handle captured arguments only very recently, which is why this issue did not occur previously. |
There are actually multiple issues here. A few of them are related directly to this PR and I fixed them locally (some are just typos and others are related to the recent changes in autodiff code). The root cause is somewhere around reabstraction thunks from I managed to significantly reduce one of the test cases by stripping away layers and layers of abstraction :) Here is the piece of optimized LLVM IR showing the issue:
Here %16 is a differential call. Note that |
Ok, so, I'm a bit confused. This is what we have here: a reabstraction thunk to convert from This is how everything is organized:
...
// function_ref thunk for @callee_guaranteed (@unowned Float) -> (@unowned Float, @owned @escaping @callee_owned (@unowned Float) -> (@unowned Float))
%20 = function_ref @$sS4fIexyd_Igydo_S2fxq_r0_lyS2fIsegnr_Iegnro_TR : $@convention(thin) (@in_guaranteed Float, @noescape @callee_guaranteed (Float) -> (Float, @owned @callee_owned (Float) -> Float)) -> (@out Float, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <Float, Float>), loc "forward_mode_simple2.swift":52:64, scope 33 // user: %21
%21 = partial_apply [callee_guaranteed] %20(%19) : $@convention(thin) (@in_guaranteed Float, @noescape @callee_guaranteed (Float) -> (Float, @owned @callee_owned (Float) -> Float)) -> (@out Float, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <Float, Float>), loc "forward_mode_simple2.swift":52:64, scope 33 // user: %22
%22 = convert_function %21 : $@callee_guaranteed (@in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <Float, Float>) to $@callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2, τ_0_3> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <τ_0_2, τ_0_3>) for <Float, Float, Float, Float>, loc "forward_mode_simple2.swift":52:64, scope 33 // users: %54, %23
...
strong_release %22 : $@callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2, τ_0_3> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <τ_0_2, τ_0_3>) for <Float, Float, Float, Float>, loc "<compiler-generated>":0:0, scope 33 // id: %54
...
and thunk itself is:
// thunk for @callee_guaranteed (@unowned Float) -> (@unowned Float)
sil shared [transparent] [reabstraction_thunk] @$sS2fIgyd_S2fIegnr_TR : $@convention(thin) (@in_guaranteed Float, @noescape @callee_guaranteed (Float) -> Float) -> @out Float {
// %0 // user: %5
// %1 // user: %3
// %2 // user: %4
bb0(%0 : $*Float, %1 : $*Float, %2 : $@noescape @callee_guaranteed (Float) -> Float):
%3 = load %1 : $*Float, loc "<compiler-generated>":0:0, scope 34 // user: %4
%4 = apply %2(%3) : $@noescape @callee_guaranteed (Float) -> Float, loc "<compiler-generated>":0:0, scope 34 // user: %5
store %4 to %0 : $*Float, loc "<compiler-generated>":0:0, scope 34 // id: %5
%6 = tuple (), loc "<compiler-generated>":0:0, scope 34 // user: %7
return %6 : $(), loc "<compiler-generated>":0:0, scope 34 // id: %7
} // end sil function '$sS2fIgyd_S2fIegnr_TR'
// thunk for @callee_guaranteed (@unowned Float) -> (@unowned Float, @owned @escaping @callee_owned (@unowned Float) -> (@unowned Float))
sil shared [transparent] [reabstraction_thunk] @$sS4fIexyd_Igydo_S2fxq_r0_lyS2fIsegnr_Iegnro_TR : $@convention(thin) (@in_guaranteed Float, @noescape @callee_guaranteed (Float) -> (Float, @owned @callee_owned (Float) -> Float)) -> (@out Float, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <Float, Float>) {
// %0 // user: %7
// %1 // user: %3
// %2 // user: %4
bb0(%0 : $*Float, %1 : $*Float, %2 : $@noescape @callee_guaranteed (Float) -> (Float, @owned @callee_owned (Float) -> Float)):
%3 = load %1 : $*Float, loc "<compiler-generated>":0:0, scope 35 // user: %4
%4 = apply %2(%3) : $@noescape @callee_guaranteed (Float) -> (Float, @owned @callee_owned (Float) -> Float), loc "<compiler-generated>":0:0, scope 35 // users: %6, %5
%5 = tuple_extract %4 : $(Float, @callee_owned (Float) -> Float), 0, loc "<compiler-generated>":0:0, scope 35 // user: %7
%6 = tuple_extract %4 : $(Float, @callee_owned (Float) -> Float), 1, loc "<compiler-generated>":0:0, scope 35 // user: %9
store %5 to %0 : $*Float, loc "<compiler-generated>":0:0, scope 35 // id: %7
// function_ref thunk for @escaping @callee_owned (@unowned Float) -> (@unowned Float)
%8 = function_ref @$sS2fIexyd_S2fIegnr_TR : $@convention(thin) (@in_guaranteed Float, @owned @callee_owned (Float) -> Float) -> @out Float, loc "<compiler-generated>":0:0, scope 35 // user: %9
%9 = partial_apply [callee_guaranteed] %8(%6) : $@convention(thin) (@in_guaranteed Float, @owned @callee_owned (Float) -> Float) -> @out Float, loc "<compiler-generated>":0:0, scope 35 // user: %10
%10 = convert_function %9 : $@callee_guaranteed (@in_guaranteed Float) -> @out Float to $@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <Float, Float>, loc "<compiler-generated>":0:0, scope 35 // user: %11
return %10 : $@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <Float, Float>, loc "<compiler-generated>":0:0, scope 35 // id: %11
} // end sil function '$sS4fIexyd_Igydo_S2fxq_r0_lyS2fIsegnr_Iegnro_TR'
The conversion is essentially done via a closure. The outer closure (of the thunk) is correctly released. However, where and how the inner one (
define linkonce_odr hidden swiftcc { i8*, %swift.refcounted* } @"$sS4fIexyd_Igydo_S2fxq_r0_lyS2fIsegnr_Iegnro_TR"(%TSf* noalias nocapture %0, %TSf* noalias nocapture dereferenceable(4) %1, i8* %2, %swift.opaque* %3) #0 !dbg !314 {
entry:
%._value = getelementptr inbounds %TSf, %TSf* %1, i32 0, i32 0, !dbg !319
%4 = load float, float* %._value, align 4, !dbg !319
%5 = bitcast i8* %2 to { float, i8*, %swift.refcounted* } (float, %swift.refcounted*)*, !dbg !319
%6 = bitcast %swift.opaque* %3 to %swift.refcounted*, !dbg !319
%7 = call swiftcc { float, i8*, %swift.refcounted* } %5(float %4, %swift.refcounted* swiftself %6) #17, !dbg !319
%8 = extractvalue { float, i8*, %swift.refcounted* } %7, 0, !dbg !319
%9 = extractvalue { float, i8*, %swift.refcounted* } %7, 1, !dbg !319
%10 = extractvalue { float, i8*, %swift.refcounted* } %7, 2, !dbg !319
%._value1 = getelementptr inbounds %TSf, %TSf* %0, i32 0, i32 0, !dbg !319
store float %8, float* %._value1, align 4, !dbg !319
%11 = call noalias %swift.refcounted* @swift_allocObject(%swift.type* getelementptr inbounds (%swift.full_boxmetadata, %swift.full_boxmetadata* @metadata.50, i32 0, i32 2), i64 32, i64 7) #5, !dbg !319
%12 = bitcast %swift.refcounted* %11 to <{ %swift.refcounted, %swift.function }>*, !dbg !319
%13 = getelementptr inbounds <{ %swift.refcounted, %swift.function }>, <{ %swift.refcounted, %swift.function }>* %12, i32 0, i32 1, !dbg !319
%.fn = getelementptr inbounds %swift.function, %swift.function* %13, i32 0, i32 0, !dbg !319
store i8* %9, i8** %.fn, align 8, !dbg !319
%.data = getelementptr inbounds %swift.function, %swift.function* %13, i32 0, i32 1, !dbg !319
store %swift.refcounted* %10, %swift.refcounted** %.data, align 8, !dbg !319
%14 = insertvalue { i8*, %swift.refcounted* } { i8* bitcast (void (%TSf*, %TSf*, %swift.refcounted*)* @"$sS2fIexyd_S2fIegnr_TRTA" to i8*), %swift.refcounted* undef }, %swift.refcounted* %11, 1, !dbg !319
ret { i8*, %swift.refcounted* } %14, !dbg !319
}
So, here the fresh context object ( Tagging @gottesmm for some guidance as the reabstraction thunk code is not autodiff-specific here :) |
@compnerd @gottesmm Do you happen to have any idea what might be wrong with the reabstraction thunks in #34935 (comment)? Thanks! |
@rxwei Wanted to resurrect this. Could you please remind me why switching from structs to tuples is crucial here? |
AST function types don't support the After switching to tuple, you can now use |
Well, I'm afraid we'd still need to use it, as tuples also use AST function types. Otherwise various optimization passes (e.g. the specializer) assert: they cannot, for example, specialize lowered function types... |
Now that I try to remember, by "tuple" what I really meant was for linear map struct elements to be directly added to the VJP function's result list and to the pullback function's parameter list, i.e. no nested tuples. That way I don't think there's supposed to be any attempts to reconstruct AST types. |
Ah, so we'd essentially unwrap the current tuples, right? |
Yes. It'd be good to confirm the max parameter limit before committing to this, since our tuple size basically grows linearly with respect to the number of instructions (more specifically, |
Right. However, in very complex cases we can fall back to the wrapped tuple implementation. Essentially, we know the size of the tuple, so we can decide whether we want to unwrap it or not. |
Tuples have a limit too. I forgot the number |
32 bits are used to store the number of elements. I think we're ok here :) The number of function parameters is stored in 32 bits as well. Results use 16 bits for the count. |
So, I'm afraid things are a little bit more complicated:
Also, even if we'd introduce reabstraction, the leaks do not go away, so I'm afraid |
Switch to `@callee_owned` callee convention for all linear map functions (differentials and pullbacks) returned from derivative functions. This halves the reference counting operations in compiler-generated derivatives, and enables child linear maps that are called in linear maps to be destroyed right after the call.
Background: Before this patch, linear map functions took an `@owned` context and had the `@callee_guaranteed` callee convention. This was a suboptimal design: the context is `@owned` because we want to consume it as early as possible, but doing so in combination with the `@callee_guaranteed` convention leads to an unnecessary pair of retain (in the partial application forwarder) and release (in the caller). As a result, the entire context was kept alive until the entire outer pullback returned.
Resolves rdar://71892494.
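The retain/release pair described in the background can be illustrated with a small model. This is a minimal sketch under simplifying assumptions, not compiler code: `ModelContext`, `linearMapBody`, `callCalleeGuaranteed`, and `callCalleeOwned` are hypothetical names, and reference counting is simulated by hand rather than performed by the Swift runtime.

```swift
// Hypothetical, hand-rolled model of closure-context reference counting.
final class ModelContext {
    private(set) var refCount = 1        // the closure value starts at +1
    private(set) var refCountingOps = 0  // number of retains and releases
    private(set) var isDestroyed = false

    func retain() {
        refCount += 1
        refCountingOps += 1
    }

    func release() {
        refCount -= 1
        refCountingOps += 1
        if refCount == 0 { isDestroyed = true }
    }
}

/// The linear map body takes its context owned and consumes it as early
/// as possible. Returns whether the context died before the body returned.
func linearMapBody(_ context: ModelContext) -> Bool {
    context.release()
    return context.isDestroyed
}

/// Callee-guaranteed model: the forwarder must retain the context to hand
/// an owned copy to the body, and the caller releases it after the call.
func callCalleeGuaranteed(_ context: ModelContext) -> Bool {
    context.retain()                          // partial application forwarder
    let destroyedEarly = linearMapBody(context)
    context.release()                         // caller
    return destroyedEarly
}

/// Callee-owned model: the call itself consumes the context.
func callCalleeOwned(_ context: ModelContext) -> Bool {
    return linearMapBody(context)
}

let guaranteed = ModelContext()
let guaranteedDestroyedEarly = callCalleeGuaranteed(guaranteed)
let owned = ModelContext()
let ownedDestroyedEarly = callCalleeOwned(owned)

print(guaranteed.refCountingOps, guaranteedDestroyedEarly, guaranteed.isDestroyed) // prints "3 false true"
print(owned.refCountingOps, ownedDestroyedEarly, owned.isDestroyed)                // prints "1 true true"
```

In this model the callee-guaranteed call performs two extra reference counting operations and keeps the context alive until after the body returns, while the callee-owned call lets the body destroy the context immediately. The exact counts in real generated code depend on the derivative being compiled.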