[AutoDiff] WIP: Use owned callee convention for linear maps. #34935

Draft
wants to merge 1 commit into
base: main

rxwei
Contributor

@rxwei rxwei commented Dec 3, 2020

Switch to the @callee_owned callee convention for all linear map functions (differentials and pullbacks) returned from derivative functions. This eliminates half of the reference counting operations in compiler-generated derivatives and allows child linear maps that are called within linear maps to be destroyed right after the call.

Background: Before this patch, linear map functions took an @owned context and had the @callee_guaranteed callee convention. This design was suboptimal because:

  1. All linear maps in AD-generated code are called exactly once. The context's parameter convention is @owned because we want to consume the context as early as possible, but doing so in combination with the @callee_guaranteed convention leads to an unnecessary pair of retain (in the partial application forwarder) and release (in the caller). As a result, the entire context was kept alive until the entire outer pullback returned.
  2. Pullbacks' allocation and deallocation follow a strict stack discipline, so we really want to consume pullbacks as early as possible and not retain unused memory.
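
The call-once, stack-like behaviour described in the two points above can be pictured with hand-written VJPs in plain Swift. This is only a sketch: `vjpSquare`, `vjpTriple`, and `vjpTripleOfSquare` are hypothetical names, ordinary closures stand in for compiler-generated linear maps, and Swift source cannot express callee conventions, so nothing here reflects the SIL the compiler actually emits.

```swift
// A pullback maps an output cotangent back to an input cotangent.
typealias Pullback = (Float) -> Float

// Hypothetical hand-written VJP of x * x; the pullback captures x.
func vjpSquare(_ x: Float) -> (value: Float, pullback: Pullback) {
    return (x * x, { v in 2 * x * v })
}

// Hypothetical hand-written VJP of 3 * x; the pullback captures nothing.
func vjpTriple(_ x: Float) -> (value: Float, pullback: Pullback) {
    return (3 * x, { v in 3 * v })
}

// VJP of triple(square(x)), composed by hand. The outer pullback calls
// each child pullback exactly once, in the reverse order of the forward pass.
func vjpTripleOfSquare(_ x: Float) -> (value: Float, pullback: Pullback) {
    let (y, pullbackSquare) = vjpSquare(x)
    let (z, pullbackTriple) = vjpTriple(y)
    return (z, { v in pullbackSquare(pullbackTriple(v)) })
}

let (value, pullback) = vjpTripleOfSquare(2)
let gradientAtTwo = pullback(1)
print(value, gradientAtTwo)
```

Because each child pullback runs once and in reverse order, its captured context is no longer needed as soon as that single call returns, which is the property the owned callee convention is meant to exploit.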

Resolves rdar://71892494.

Contributor

@dan-zheng dan-zheng left a comment

Nice!

@rxwei
Contributor Author

rxwei commented Dec 3, 2020

Note: There are two places where reabstracting back to @callee_guaranteed is needed.

  • Builtin.applyDerivative* call sites. This is because the caller expects the result to have the formal lowered type. There isn't a way to control callee convention in the AST. This reabstraction should be optimized away.
  • Linear map struct fields. Similar to the case above, struct fields have AST types and thus cannot store @callee_owned closures directly. Two ways to eliminate the reabstraction are a) bitcasting and storing these closures as $(Builtin.RawPointer, Builtin.NativeObject) and b) replacing linear map structs with tuples.

rxwei added a commit to rxwei/swift that referenced this pull request Jan 12, 2021
…(from:)'.

Remove unused APIs `differentiableFunction(from:)` and `linearFunction(from:)`. They were never official APIs, are not included in the [initial proposal](https://github.com/rxwei/swift-evolution/blob/autodiff/proposals/0000-differentiable-programming.md#make-a-function-differentiable-using-derivative), and are unused by existing supported client libraries (SwiftFusion and S4TF). Most importantly, they block crucial optimizations on linear map closures (swiftlang#34935) and would need nontrivial work in SILGen to support.
@rxwei rxwei force-pushed the callee-owned-linear-map branch from 513fc9a to 5b3a3e3 Compare January 15, 2021 10:40
@philipturner
Contributor

How can I complete this PR? I'm looking at completing SR-15580 as well as gaining more knowledge about the compiler in general to prepare for debugging autodiff.

@philipturner
Contributor

@CodaFi I know this is a stretch, but do you have the capacity to help me out here? It seems like I'll have to wait a long time before anybody in the old S4TF team does.

@CodaFi
Contributor

CodaFi commented Jan 19, 2022

I'm not sure why this PR is a work in progress. Perhaps some tests didn't pass? Try rebasing it and we can get CI running to see what its current state is.

@philipturner
Contributor

philipturner commented Jan 19, 2022

@CodaFi I'm having trouble figuring out how to rebase. I tried making a new PR from this branch to main and there are conflicts. However, instead of letting me make a new PR, it only gives me the link to this PR. I'm probably doing something wrong because I have limited experience with Git.

main...rxwei:callee-owned-linear-map

@rxwei
Contributor Author

rxwei commented Jan 22, 2022

Thanks for the ping. I'm giving this a shot this weekend.

Switch to the `@callee_owned` callee convention for all linear map functions (differentials and pullbacks) returned from derivative functions. This eliminates half of the reference counting operations in compiler-generated derivatives and allows child linear maps that are called within linear maps to be destroyed right after the call.

Resolves rdar://71892494.
@rxwei rxwei force-pushed the callee-owned-linear-map branch from 5b3a3e3 to e8ebcc5 Compare January 23, 2022 11:15
@rxwei
Contributor Author

rxwei commented Jan 23, 2022

Just rebased on top of main and most things are working. For anyone interested in picking up this work: this was still WIP because some unexpected memory leaks cause these validation tests to fail:

  Swift(macosx-x86_64) :: AutoDiff/validation-test/differentiable_protocol_requirements.swift
  Swift(macosx-x86_64) :: AutoDiff/validation-test/existential.swift
  Swift(macosx-x86_64) :: AutoDiff/validation-test/forward_mode_simple.swift

For example:

[ RUN      ] ProtocolRequirementDifferentiation.func
stdout>>> check failed at /Volumes/Media/Development/Swift/swift-source/swift/test/AutoDiff/validation-test/differentiable_protocol_requirements.swift, line 55
stdout>>> expected: 0 (of type Swift.Int)
stdout>>> actual: 12 (of type Swift.Int)
stdout>>> Leaks detected: 12
[     FAIL ] ProtocolRequirementDifferentiation.func
[ RUN      ] ProtocolRequirementDifferentiation.constructor, accessor, subscript
stdout>>> check failed at /Volumes/Media/Development/Swift/swift-source/swift/test/AutoDiff/validation-test/differentiable_protocol_requirements.swift, line 117
stdout>>> expected: 0 (of type Swift.Int)
stdout>>> actual: 1 (of type Swift.Int)
stdout>>> Leaks detected: 1
[     FAIL ] ProtocolRequirementDifferentiation.constructor, accessor, subscript
ProtocolRequirementDifferentiation: Some tests failed, aborting
UXPASS: []
FAIL: ["func", "constructor, accessor, subscript"]
SKIP: []
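
The "Leaks detected" check in the log above can be pictured with a small counting harness. This is a sketch under stated assumptions, not the DifferentiationUnittest implementation: `LiveCounter`, `TrackedBox`, and `leaksDetected(in:)` are hypothetical names, and the leak is forced by hand with `Unmanaged.passRetained` rather than produced by the differentiation transform.

```swift
/// Counts how many tracked boxes are currently alive (hypothetical helper).
final class LiveCounter {
    var count = 0
}

/// Hypothetical stand-in for a tracked value: registers itself on init
/// and unregisters on deinit.
final class TrackedBox {
    let value: Float
    private let counter: LiveCounter

    init(_ value: Float, counter: LiveCounter) {
        self.value = value
        self.counter = counter
        counter.count += 1
    }

    deinit {
        counter.count -= 1
    }
}

/// Runs `body` and reports how many boxes are still alive afterwards.
func leaksDetected(in body: (LiveCounter) -> Void) -> Int {
    let counter = LiveCounter()
    body(counter)
    return counter.count
}

// Balanced: the box is released when the closure returns.
let cleanRun = leaksDetected { counter in
    let x = TrackedBox(3, counter: counter)
    _ = x.value * 2
}

// Deliberately unbalanced: an extra +1 that nobody ever releases.
let leakyRun = leaksDetected { counter in
    let x = TrackedBox(3, counter: counter)
    _ = Unmanaged.passRetained(x)
}

print("clean run leaks:", cleanRun, "leaky run leaks:", leakyRun)
```

A test in this style expects the count to be zero after the body returns, so any context that stays retained shows up as a nonzero "actual" value.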

TODO:

  • Debug and fix memory leaks.
  • Pick up FileCheck test changes from 5b3a3e3.

@rxwei
Contributor Author

rxwei commented Jan 23, 2022

cc @BradLarson

@philipturner
Contributor

philipturner commented Jan 23, 2022

Thanks for investigating this! I don’t know if my current experience with the compiler is sufficient to fix the memory leak, though. Maybe after implementing the cross-import overlay.

I’d like to take on this bug myself in the future, because it’s a great stepping stone for future debugging experience. It’s small, specific and tractable. Since changing the calling convention is just an optimization, would it be okay if it got delayed because I waited to fix it?

@rxwei
Contributor Author

rxwei commented Jan 24, 2022

@philipturner This is really not a "small, specific and tractable" task for Swift compiler beginners. It requires a good understanding of SIL and the differentiation transform, because the leak is likely due to incorrectly generated code within SILGen and the differentiation transform. Also, it's not "just an optimization", as it's a major ABI change.

I don't have the bandwidth to tackle this in the near future, and was hoping that @BradLarson (and team) could take a look when they need to work on further optimizations that are blocked by this. @BradLarson should also be able to suggest some starter tasks for you.

@philipturner
Contributor

Okay, thanks for the advice! I think the cross-import overlay implementation might go quickly, so I might need starter tasks soon. In the meantime, I'll look for ways to get familiar with the Swift compiler on my own.

@asl
Contributor

asl commented Feb 17, 2022

The SIL for the generated pullback looks a bit suspicious to me:

// pullback of B.a(_:)
sil private [ossa] @$s4main1BV1ay23DifferentiationUnittest7TrackedVySfGAHFTJpSUpSr : $@convention(thin) (@in_guaranteed Tracked<Float>, @owned _AD__$s4main1BV1ay23DifferentiationUnittest7TrackedVySfGAHF_bb0__PB__src_0_wrt_0) -> @out Tracked<Float> {
// %0                                             // user: %19
// %1                                             // user: %8
// %2                                             // user: %9
bb0(%0 : $*Tracked<Float>, %1 : $*Tracked<Float>, %2 : @owned $_AD__$s4main1BV1ay23DifferentiationUnittest7TrackedVySfGAHF_bb0__PB__src_0_wrt_0):
  %3 = alloc_stack $Tracked<Float>, let, name "x", argno 1, expr op_deref // users: %22, %19, %16, %6
  %4 = witness_method $Tracked<Float>, #AdditiveArithmetic.zero!getter : <Self where Self : AdditiveArithmetic> (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0 // user: %6
  %5 = metatype $@thick Tracked<Float>.Type       // user: %6
  %6 = apply %4<Tracked<Float>>(%3, %5) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
  %7 = alloc_stack $Tracked<Float>                // users: %21, %20, %12, %8
  copy_addr %1 to [initialization] %7 : $*Tracked<Float> // id: %8
  %9 = destructure_struct %2 : $_AD__$s4main1BV1ay23DifferentiationUnittest7TrackedVySfGAHF_bb0__PB__src_0_wrt_0 // user: %11
  %10 = alloc_stack $Tracked<Float>               // users: %18, %17, %16, %12
  %11 = unchecked_value_cast %9 : $@callee_guaranteed (@in_guaranteed Tracked<Float>) -> @out Tracked<Float> to $@callee_owned (@in_guaranteed Tracked<Float>) -> @out Tracked<Float> // user: %12
  %12 = apply %11(%10, %7) : $@callee_owned (@in_guaranteed Tracked<Float>) -> @out Tracked<Float> // user: %13
  destructure_tuple %12 : $()                     // id: %13
  %14 = witness_method $Tracked<Float>, #AdditiveArithmetic."+=" : <Self where Self : AdditiveArithmetic> (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> () // user: %16
  %15 = metatype $@thick Tracked<Float>.Type      // user: %16
  %16 = apply %14<Tracked<Float>>(%3, %10, %15) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
  destroy_addr %10 : $*Tracked<Float>             // id: %17
  dealloc_stack %10 : $*Tracked<Float>            // id: %18
  copy_addr [take] %3 to [initialization] %0 : $*Tracked<Float> // id: %19
  destroy_addr %7 : $*Tracked<Float>              // id: %20
  dealloc_stack %7 : $*Tracked<Float>             // id: %21
  dealloc_stack %3 : $*Tracked<Float>             // id: %22
  %23 = tuple ()                                  // user: %24
  return %23 : $()                                // id: %24
} // end sil function '$s4main1BV1ay23DifferentiationUnittest7TrackedVySfGAHFTJpSUpSr'

Note the following:

  %11 = unchecked_value_cast %9 : $@callee_guaranteed (@in_guaranteed Tracked<Float>) -> @out Tracked<Float> to $@callee_owned (@in_guaranteed Tracked<Float>) -> @out Tracked<Float> // user: %12
  %12 = apply %11(%10, %7) : $@callee_owned (@in_guaranteed Tracked<Float>) -> @out Tracked<Float> // user: %13

So it looks like we're changing the calling convention from @callee_guaranteed to @callee_owned here. As a result, no release is generated after the function call, and because the callee assumes that the caller will release the context, it does not bother releasing it itself. The LLVM IR generated on main does indeed have an extra @swift_release call.
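
The kind of mismatch described here, where each side assumes the other will release the context, can be imitated in plain Swift with manual reference counting through `Unmanaged`. This is an analogy only: `ReleaseFlag`, `ClosureContext`, `calleeGuaranteed`, `calleeOwned`, and `contextIsReleased(after:)` are hypothetical names, and real callee conventions live in SIL function types rather than in Swift source.

```swift
/// Flipped to true when the context object is deallocated.
final class ReleaseFlag {
    var released = false
}

/// Hypothetical stand-in for a closure context.
final class ClosureContext {
    private let flag: ReleaseFlag
    init(flag: ReleaseFlag) { self.flag = flag }
    deinit { flag.released = true }
}

/// "Guaranteed"-style callee: only borrows the context, so the caller
/// is expected to release it afterwards.
func calleeGuaranteed(_ context: Unmanaged<ClosureContext>) {
    _ = context.takeUnretainedValue()
}

/// "Owned"-style callee: consumes the +1 reference it is handed.
func calleeOwned(_ context: Unmanaged<ClosureContext>) {
    _ = context.takeRetainedValue()
}

/// The caller hands over a +1 reference and never releases it itself,
/// i.e. it behaves as if the callee were "owned".
func contextIsReleased(after callee: (Unmanaged<ClosureContext>) -> Void) -> Bool {
    let flag = ReleaseFlag()
    let context = Unmanaged.passRetained(ClosureContext(flag: flag))
    callee(context)
    return flag.released
}

let releasedWhenMatched = contextIsReleased(after: calleeOwned)
let releasedWhenMismatched = contextIsReleased(after: calleeGuaranteed)
print(releasedWhenMatched, releasedWhenMismatched)
```

When the caller's assumption matches the callee, the context is freed during the call; when it does not, the context stays alive with no remaining owner, which is what a leak checker would report.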

@rxwei
Contributor Author

rxwei commented Feb 17, 2022

> So it looks like we're changing the calling convention from @callee_guaranteed to @callee_owned here. As a result, no release is generated after the function call, and because the callee assumes that the caller will release the context, it does not bother releasing it itself.

These pullbacks were actually @callee_owned in the first place, bitcast to @callee_guaranteed by VJPCloner (VJPCloner.cpp:679-694). The bitcast is necessary for us to be able to store the pullbacks inside an AST struct, as callee conventions cannot be specified in AST function types.

@asl
Contributor

asl commented Feb 17, 2022

> > So it looks like we're changing the calling convention from @callee_guaranteed to @callee_owned here. As a result, no release is generated after the function call, and because the callee assumes that the caller will release the context, it does not bother releasing it itself.

> These pullbacks were actually @callee_owned in the first place, bitcast to @callee_guaranteed by VJPCloner (VJPCloner.cpp:679-694). The bitcast is necessary for us to be able to store the pullbacks inside an AST struct, as callee conventions cannot be specified in AST function types.

Yes, makes sense. Ok, in terms of leaked values in existential.swift, there are two allocs of "3" from:

func b(g: A) -> Tracked<Float> {
  return gradient(at: 3) { x in g.a(x) }
}

but we release only one

@rxwei
Contributor Author

rxwei commented Feb 17, 2022

Weird... If there's nothing I missed from applying the pullback bitcasting workaround, I'd expect any leaks to trigger an OSSA verification failure. Something apparently slipped through. We can check if this 3 was captured by pullbacks.

@asl
Contributor

asl commented Feb 17, 2022

> Weird... If there's nothing I missed from applying the pullback bitcasting workaround, I'd expect any leaks to trigger an OSSA verification failure. Something apparently slipped through. We can check if this 3 was captured by pullbacks.

Things might be a bit more interesting. The corresponding Tracked<Float> should be released at the very end of b(g: A), and this is what happens on main, but not here. Stay tuned :)

@asl
Contributor

asl commented Feb 17, 2022

So, yes. On main, at the end of b(g: A) we have just one unowned refcount for 3, so we release the object and therefore deallocate it. On the branch we have an additional strong refcount, so we cannot deinit it. Something is holding 3 somewhere.

@asl
Contributor

asl commented Feb 18, 2022

So, the value is captured in the pullback of * here:

extension ${Self}
where
  T: Differentiable & SignedNumeric, T == T.Magnitude,
  T == T.TangentVector
{
  @usableFromInline
  @derivative(of: *)
  internal static func _vjpMultiply(lhs: Self, rhs: Self)
    -> (value: Self, pullback: (Self) -> (Self, Self))
  {
    return (lhs * rhs, { v in (v * rhs, v * lhs) })
  }
}

It seems like this pullback is never released.
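
How a pullback like this keeps its operands alive can be shown with a non-generic analogue. This is a sketch under stated assumptions: `Box`, `vjpMultiply`, and `captureDemo` are hypothetical names, a plain class stands in for the tracked value type, and a weak reference is used to observe whether the captured operand is still alive.

```swift
/// Hypothetical reference-counted operand.
final class Box {
    let value: Float
    init(_ value: Float) { self.value = value }
}

/// Hand-written analogue of the VJP above: the pullback captures both operands.
func vjpMultiply(_ lhs: Box, _ rhs: Box)
    -> (value: Float, pullback: (Float) -> (Float, Float))
{
    return (lhs.value * rhs.value, { v in (v * rhs.value, v * lhs.value) })
}

func captureDemo() -> (gradient: (Float, Float), aliveWhileHeld: Bool, aliveAfterDrop: Bool) {
    weak var weakRhs: Box?
    var pullback: ((Float) -> (Float, Float))?
    do {
        let lhs = Box(3)
        let rhs = Box(4)
        weakRhs = rhs
        pullback = vjpMultiply(lhs, rhs).pullback
    }
    // The operands are out of scope, but the pullback still captures them.
    let aliveWhileHeld = weakRhs != nil
    let gradient = pullback!(1)
    // Dropping the pullback right after its only call frees the captures.
    pullback = nil
    let aliveAfterDrop = weakRhs != nil
    return (gradient, aliveWhileHeld, aliveAfterDrop)
}

let demo = captureDemo()
print(demo)
```

If the pullback were never released, the captured operand would stay alive indefinitely, which matches the observation that something keeps holding the value 3.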

@rxwei
Contributor Author

rxwei commented Feb 18, 2022

Could this be an IRGen issue when it's emitting a non-[callee_guaranteed] partial_apply?

It's safe to assume that this library-defined pullback is emitted correctly since its closure lowering is not changed by this PR. Library-defined pullbacks are being captured by parent pullbacks generated by the differentiation transform via non-[callee_guaranteed] partial_applys.

@asl
Contributor

asl commented Feb 18, 2022

> Could this be an IRGen issue when it's emitting a non-[callee_guaranteed] partial_apply?

The partial apply is [callee_guaranteed] here, yes. And it's pretty much the same on main: the values are also captured there.

@rxwei
Contributor Author

rxwei commented Feb 18, 2022

> > Could this be an IRGen issue when it's emitting a non-[callee_guaranteed] partial_apply?

> The partial apply is [callee_guaranteed] here, yes. And it's pretty much the same on main: the values are also captured there.

Right, this particular partial_apply (library defined VJP) is [callee_guaranteed] and I don't think there's any issues with it. But the resulting pullback may be captured by outer pullbacks using a non-[callee_guaranteed] partial_apply, forming @callee_owned closures.

@asl
Contributor

asl commented Feb 18, 2022

> > > Could this be an IRGen issue when it's emitting a non-[callee_guaranteed] partial_apply?

> > The partial apply is [callee_guaranteed] here, yes. And it's pretty much the same on main: the values are also captured there.

> Right, this particular partial_apply (library defined VJP) is [callee_guaranteed] and I don't think there's any issues with it. But the resulting pullback may be captured by outer pullbacks using a non-[callee_guaranteed] partial_apply, forming @callee_owned closures.

Yes, it's one level up:

// reverse-mode derivative of static Tracked<A>.* infix(_:_:)
sil [thunk] [always_inline] [ossa] @$s23DifferentiationUnittest7TrackedVAASjRzlE1moiyACyxGAE_AEtFZs13SignedNumericRz01_A014DifferentiableRz9MagnitudeSjQzRsz13TangentVectorAgHPQzAJRSlTJrSSUpSr : $@convention(method) <τ_0_0 where τ_0_0 : SignedNumeric, τ_0_0 : Differentiable, τ_0_0 == τ_0_0.Magnitude, τ_0_0.Magnitude == τ_0_0.TangentVector> (@in_guaranteed Tracked<τ_0_0>, @in_guaranteed Tracked<τ_0_0>, @thin Tracked<τ_0_0>.Type) -> (@out Tracked<τ_0_0>, @owned @callee_owned @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @out τ_0_2) for <Tracked<τ_0_0.TangentVector>, Tracked<τ_0_0.TangentVector>, Tracked<τ_0_0.TangentVector>>) {
// %0                                             // user: %5
// %1                                             // user: %5
// %2                                             // user: %5
// %3                                             // user: %5
bb0(%0 : $*Tracked<τ_0_0>, %1 : $*Tracked<τ_0_0>, %2 : $*Tracked<τ_0_0>, %3 : $@thin Tracked<τ_0_0>.Type):
  // function_ref static Tracked<A>._vjpMultiply(lhs:rhs:)
  %4 = function_ref @$s23DifferentiationUnittest7TrackedVAAs13SignedNumericRz01_A014DifferentiableRz9MagnitudeSjQzRsz13TangentVectorAeFPQzAHRSlE12_vjpMultiply3lhs3rhsACyxG5value_AO_AOtAOc8pullbacktAO_AOtFZ : $@convention(method) <τ_0_0 where τ_0_0 : SignedNumeric, τ_0_0 : Differentiable, τ_0_0 == τ_0_0.Magnitude, τ_0_0.Magnitude == τ_0_0.TangentVector> (@in_guaranteed Tracked<τ_0_0>, @in_guaranteed Tracked<τ_0_0>, @thin Tracked<τ_0_0>.Type) -> (@out Tracked<τ_0_0>, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2, τ_0_3, τ_0_4, τ_0_5 where τ_0_0 == τ_0_1, τ_0_2 == τ_0_3, τ_0_4 == τ_0_5> (@in_guaranteed Tracked<τ_0_0>) -> (@out Tracked<τ_0_2>, @out Tracked<τ_0_4>) for <τ_0_0, τ_0_0, τ_0_0, τ_0_0, τ_0_0, τ_0_0>) // user: %5
  %5 = apply %4<τ_0_0>(%0, %1, %2, %3) : $@convention(method) <τ_0_0 where τ_0_0 : SignedNumeric, τ_0_0 : Differentiable, τ_0_0 == τ_0_0.Magnitude, τ_0_0.Magnitude == τ_0_0.TangentVector> (@in_guaranteed Tracked<τ_0_0>, @in_guaranteed Tracked<τ_0_0>, @thin Tracked<τ_0_0>.Type) -> (@out Tracked<τ_0_0>, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2, τ_0_3, τ_0_4, τ_0_5 where τ_0_0 == τ_0_1, τ_0_2 == τ_0_3, τ_0_4 == τ_0_5> (@in_guaranteed Tracked<τ_0_0>) -> (@out Tracked<τ_0_2>, @out Tracked<τ_0_4>) for <τ_0_0, τ_0_0, τ_0_0, τ_0_0, τ_0_0, τ_0_0>) // user: %6
  %6 = convert_function %5 : $@callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2, τ_0_3, τ_0_4, τ_0_5 where τ_0_0 == τ_0_1, τ_0_2 == τ_0_3, τ_0_4 == τ_0_5> (@in_guaranteed Tracked<τ_0_0>) -> (@out Tracked<τ_0_2>, @out Tracked<τ_0_4>) for <τ_0_0, τ_0_0, τ_0_0, τ_0_0, τ_0_0, τ_0_0> to $@callee_guaranteed (@in_guaranteed Tracked<τ_0_0>) -> (@out Tracked<τ_0_0>, @out Tracked<τ_0_0>) // user: %8
  // function_ref thunk for @escaping @callee_guaranteed (@in_guaranteed Tracked<A>) -> (@out Tracked<A>, @out Tracked<A>)
  %7 = function_ref @$s23DifferentiationUnittest7TrackedVyxGA2DIegnrr_A3DIexnrr_s13SignedNumericRz01_A014DifferentiableRz9MagnitudeSjQzRsz13TangentVectorAfGPQzAIRSlTR : $@convention(thin) <τ_0_0 where τ_0_0 : SignedNumeric, τ_0_0 : Differentiable, τ_0_0 == τ_0_0.Magnitude, τ_0_0.Magnitude == τ_0_0.TangentVector> (@in_guaranteed Tracked<τ_0_0>, @guaranteed @callee_guaranteed (@in_guaranteed Tracked<τ_0_0>) -> (@out Tracked<τ_0_0>, @out Tracked<τ_0_0>)) -> (@out Tracked<τ_0_0>, @out Tracked<τ_0_0>) // user: %8
  %8 = partial_apply %7<τ_0_0>(%6) : $@convention(thin) <τ_0_0 where τ_0_0 : SignedNumeric, τ_0_0 : Differentiable, τ_0_0 == τ_0_0.Magnitude, τ_0_0.Magnitude == τ_0_0.TangentVector> (@in_guaranteed Tracked<τ_0_0>, @guaranteed @callee_guaranteed (@in_guaranteed Tracked<τ_0_0>) -> (@out Tracked<τ_0_0>, @out Tracked<τ_0_0>)) -> (@out Tracked<τ_0_0>, @out Tracked<τ_0_0>) // user: %9
  %9 = convert_function %8 : $@callee_owned (@in_guaranteed Tracked<τ_0_0>) -> (@out Tracked<τ_0_0>, @out Tracked<τ_0_0>) to $@callee_owned @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @out τ_0_2) for <Tracked<τ_0_0>, Tracked<τ_0_0>, Tracked<τ_0_0>> // user: %10
  return %9 : $@callee_owned @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @out τ_0_2) for <Tracked<τ_0_0>, Tracked<τ_0_0>, Tracked<τ_0_0>> // id: %10
} // end sil function '$s23DifferentiationUnittest7TrackedVAASjRzlE1moiyACyxGAE_AEtFZs13SignedNumericRz01_A014DifferentiableRz9MagnitudeSjQzRsz13TangentVectorAgHPQzAJRSlTJrSSUpSr'

On main we have a callee_guaranteed closure here, and I do not see any differences in the generated LLVM IR.

@rxwei
Contributor Author

rxwei commented Feb 18, 2022

What about the generated partial application forwarders in LLVM IR?

@asl
Contributor

asl commented May 19, 2022

@rxwei I rebased this PR onto main and ran into a bunch of ABI-compatibility assertions. It looks like instead of reabstraction thunks (removed here: https://github.com/apple/swift/pull/34935/files#diff-01cf87f81a8c47d84be8508fac2cb0f1a4ba15919ad86d3c8904bccee60151b5L915) this PR emits just unchecked bitcasts.

However, the bitcasts are now from / to ABI-incompatible types, e.g. from
$@callee_owned (@in_guaranteed Generic<Float>.TangentVector) -> (@out Float, @out Float, @out Float) to
$@callee_guaranteed (Generic<Float>.TangentVector) -> (Float, Float, Float). This certainly will not work.

Restoring the reabstraction thunks resolved these issues (the leaks remained, but that is a separate issue) :)
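
The difference between a reabstraction thunk and a bitcast can be sketched in plain Swift. This is a loose analogy with hypothetical names (`IndirectLinearMap`, `DirectLinearMap`, `reabstract`): an `inout` result parameter stands in for SIL's indirect `@out` results, and the thunk is an ordinary wrapping closure rather than anything SILGen emits.

```swift
/// "Indirect" shape: results are written through an inout parameter,
/// loosely mirroring @out results in SIL.
typealias IndirectLinearMap = (_ results: inout (Float, Float, Float), _ v: Float) -> Void

/// "Direct" shape: results are returned by value.
typealias DirectLinearMap = (Float) -> (Float, Float, Float)

/// A reabstraction thunk in miniature: a new closure that adapts one
/// shape to the other by forwarding the call.
func reabstract(_ indirect: @escaping IndirectLinearMap) -> DirectLinearMap {
    return { v in
        var results: (Float, Float, Float) = (0, 0, 0)
        indirect(&results, v)
        return results
    }
}

// Hypothetical linear map with three results.
let indirectMap: IndirectLinearMap = { results, v in
    results = (v, 2 * v, 3 * v)
}

let directMap = reabstract(indirectMap)
let outputs = directMap(2)
print(outputs)
```

The two function types pass their results in different ways, so reinterpreting one as the other cannot work; only an adapter that actually performs the call with the right shape is safe.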

@rxwei
Contributor Author

rxwei commented May 19, 2022

> However, the bitcasts are now from / to ABI-incompatible types, e.g. from
> $@callee_owned (@in_guaranteed Generic<Float>.TangentVector) -> (@out Float, @out Float, @out Float) to
> $@callee_guaranteed (Generic<Float>.TangentVector) -> (Float, Float, Float). This certainly will not work.

That's a bit surprising though, as we should only be changing the callee convention, not the result abstraction.

@asl
Contributor

asl commented May 19, 2022

> > However, the bitcasts are now from / to ABI-incompatible types, e.g. from
> > $@callee_owned (@in_guaranteed Generic<Float>.TangentVector) -> (@out Float, @out Float, @out Float) to
> > $@callee_guaranteed (Generic<Float>.TangentVector) -> (Float, Float, Float). This certainly will not work.

> That's a bit surprising though, as we should only be changing the callee convention, not the result abstraction.

Yes. And there was nothing like this in my previous rebase...

@rxwei
Contributor Author

rxwei commented May 19, 2022

I'd recommend debugging this with the previous commit. My understanding from @gottesmm was that @callee_owned convention hasn't been used for a few years, so there may be bugs in IRGen.

@asl
Contributor

asl commented May 25, 2022

@rxwei Yes, this is the plan. I'm just documenting some things that might help us :)

@asl
Contributor

asl commented May 26, 2022

@rxwei Ok, the problem exists on main as well :) Here is what happens: for captured values we have the following pullback type: $@callee_guaranteed (Float, @inout_aliasable Float) -> Float. The lowered pullback type is $@callee_guaranteed (Float, @inout Float) -> Float (apparently the aliasable bit is silently ignored by lowering, which simply produces an inout parameter convention).

On main we're simply emitting a reabstraction thunk for pullback conversion. On branch instead we're doing a bitcast. And now the assertion is triggered because it thinks that these two types ($@callee_guaranteed (Float, @inout_aliasable Float) -> Float and $@callee_guaranteed (Float, @inout Float) -> Float) are not ABI-compatible.

We started handling captured arguments only very recently, which is why this issue did not occur previously.

@asl
Contributor

asl commented Jun 8, 2022

There are actually multiple issues here. A few of them are directly related to this PR and I fixed them locally (some are just typos and others are related to recent changes in the autodiff code). The root cause is somewhere around the reabstraction thunks from @callee_owned to @callee_guaranteed and the corresponding partial apply forwarders. The bad news is that we're leaking much more than the tests report. We are leaking the whole context of one of the functions and everything that was captured there. We were just lucky that it triggered in one of these tests.

I managed to significantly reduce one of the test cases by stripping away layers and layers of abstraction :) Here is the piece of optimized LLVM IR showing the issue:

  %13 = tail call noalias %swift.refcounted* @swift_allocObject(%swift.type* getelementptr inbounds (%swift.full_boxmetadata, %swift.full_boxmetadata* @metadata.75, i64 0, i32 2), i64 32, i64 7) #4, !noalias !101
  %14 = getelementptr inbounds %swift.refcounted, %swift.refcounted* %13, i64 1
  %.fn.i.i = bitcast %swift.refcounted* %14 to i8**
  store i8* %11, i8** %.fn.i.i, align 8, !noalias !101
  %.data.i.i = getelementptr inbounds %swift.refcounted, %swift.refcounted* %13, i64 1, i32 1
  %15 = bitcast i64* %.data.i.i to %swift.refcounted**
  store %swift.refcounted* %12, %swift.refcounted** %15, align 8, !noalias !101
  %16 = bitcast i8* %11 to float (float, %swift.refcounted*)*
  %17 = tail call %swift.refcounted* @swift_retain(%swift.refcounted* returned %12) #4, !noalias !108
  %18 = tail call swiftcc float %16(float 1.000000e+00, %swift.refcounted* swiftself %12) #17, !noalias !112
  tail call swiftcc void @"$s5main23fooyySf_SftF"(float %10, float %18)
  tail call void @swift_release(%swift.refcounted* %0) #4
  tail call void @swift_release(%swift.refcounted* %3) #4
  tail call void @swift_release(%swift.refcounted* %6) #4
  ret void

Here %16 is a differential call. Note that %13 is never released, and %12 is not consumed either, due to the extra retain. This retain comes from the @callee_owned => @callee_guaranteed reabstraction thunk (pretty reasonable). The differential here would normally consume the context from %12. However, we're missing a release of %13 here.

@asl
Contributor

asl commented Jun 9, 2022

Ok, so, I'm a bit confused. This is what we're having here: a reabstraction thunk to convert from @callee_owned to @callee_guaranteed. With some additional closures on top of this.

This is how everything is organized:

...
  // function_ref thunk for @callee_guaranteed (@unowned Float) -> (@unowned Float, @owned @escaping @callee_owned (@unowned Float) -> (@unowned Float))
  %20 = function_ref @$sS4fIexyd_Igydo_S2fxq_r0_lyS2fIsegnr_Iegnro_TR : $@convention(thin) (@in_guaranteed Float, @noescape @callee_guaranteed (Float) -> (Float, @owned @callee_owned (Float) -> Float)) -> (@out Float, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <Float, Float>), loc "forward_mode_simple2.swift":52:64, scope 33 // user: %21
  %21 = partial_apply [callee_guaranteed] %20(%19) : $@convention(thin) (@in_guaranteed Float, @noescape @callee_guaranteed (Float) -> (Float, @owned @callee_owned (Float) -> Float)) -> (@out Float, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <Float, Float>), loc "forward_mode_simple2.swift":52:64, scope 33 // user: %22
 %22 = convert_function %21 : $@callee_guaranteed (@in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <Float, Float>) to $@callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2, τ_0_3> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <τ_0_2, τ_0_3>) for <Float, Float, Float, Float>, loc "forward_mode_simple2.swift":52:64, scope 33 // users: %54, %23
 ...
  strong_release %22 : $@callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2, τ_0_3> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <τ_0_2, τ_0_3>) for <Float, Float, Float, Float>, loc "<compiler-generated>":0:0, scope 33 // id: %54
...

and the thunks themselves are:

// thunk for @callee_guaranteed (@unowned Float) -> (@unowned Float)
sil shared [transparent] [reabstraction_thunk] @$sS2fIgyd_S2fIegnr_TR : $@convention(thin) (@in_guaranteed Float, @noescape @callee_guaranteed (Float) -> Float) -> @out Float {
// %0                                             // user: %5
// %1                                             // user: %3
// %2                                             // user: %4
bb0(%0 : $*Float, %1 : $*Float, %2 : $@noescape @callee_guaranteed (Float) -> Float):
  %3 = load %1 : $*Float, loc "<compiler-generated>":0:0, scope 34 // user: %4
  %4 = apply %2(%3) : $@noescape @callee_guaranteed (Float) -> Float, loc "<compiler-generated>":0:0, scope 34 // user: %5
  store %4 to %0 : $*Float, loc "<compiler-generated>":0:0, scope 34 // id: %5
  %6 = tuple (), loc "<compiler-generated>":0:0, scope 34 // user: %7
  return %6 : $(), loc "<compiler-generated>":0:0, scope 34 // id: %7
} // end sil function '$sS2fIgyd_S2fIegnr_TR'

// thunk for @callee_guaranteed (@unowned Float) -> (@unowned Float, @owned @escaping @callee_owned (@unowned Float) -> (@unowned Float))
sil shared [transparent] [reabstraction_thunk] @$sS4fIexyd_Igydo_S2fxq_r0_lyS2fIsegnr_Iegnro_TR : $@convention(thin) (@in_guaranteed Float, @noescape @callee_guaranteed (Float) -> (Float, @owned @callee_owned (Float) -> Float)) -> (@out Float, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <Float, Float>) {
// %0                                             // user: %7
// %1                                             // user: %3
// %2                                             // user: %4
bb0(%0 : $*Float, %1 : $*Float, %2 : $@noescape @callee_guaranteed (Float) -> (Float, @owned @callee_owned (Float) -> Float)):
  %3 = load %1 : $*Float, loc "<compiler-generated>":0:0, scope 35 // user: %4
  %4 = apply %2(%3) : $@noescape @callee_guaranteed (Float) -> (Float, @owned @callee_owned (Float) -> Float), loc "<compiler-generated>":0:0, scope 35 // users: %6, %5
  %5 = tuple_extract %4 : $(Float, @callee_owned (Float) -> Float), 0, loc "<compiler-generated>":0:0, scope 35 // user: %7
  %6 = tuple_extract %4 : $(Float, @callee_owned (Float) -> Float), 1, loc "<compiler-generated>":0:0, scope 35 // user: %9
  store %5 to %0 : $*Float, loc "<compiler-generated>":0:0, scope 35 // id: %7
  // function_ref thunk for @escaping @callee_owned (@unowned Float) -> (@unowned Float)
  %8 = function_ref @$sS2fIexyd_S2fIegnr_TR : $@convention(thin) (@in_guaranteed Float, @owned @callee_owned (Float) -> Float) -> @out Float, loc "<compiler-generated>":0:0, scope 35 // user: %9
  %9 = partial_apply [callee_guaranteed] %8(%6) : $@convention(thin) (@in_guaranteed Float, @owned @callee_owned (Float) -> Float) -> @out Float, loc "<compiler-generated>":0:0, scope 35 // user: %10
  %10 = convert_function %9 : $@callee_guaranteed (@in_guaranteed Float) -> @out Float to $@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <Float, Float>, loc "<compiler-generated>":0:0, scope 35 // user: %11
  return %10 : $@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <Float, Float>, loc "<compiler-generated>":0:0, scope 35 // id: %11
} // end sil function '$sS4fIexyd_Igydo_S2fxq_r0_lyS2fIsegnr_Iegnro_TR'

The conversion is essentially done via a closure. The outer closure (of the thunk) is correctly released. However, where and how is the inner one (%9 in the thunk) supposed to be released / consumed? The LLVM IR looks as follows:

define linkonce_odr hidden swiftcc { i8*, %swift.refcounted* } @"$sS4fIexyd_Igydo_S2fxq_r0_lyS2fIsegnr_Iegnro_TR"(%TSf* noalias nocapture %0, %TSf* noalias nocapture dereferenceable(4) %1, i8* %2, %swift.opaque* %3) #0 !dbg !314 {
entry:
  %._value = getelementptr inbounds %TSf, %TSf* %1, i32 0, i32 0, !dbg !319
  %4 = load float, float* %._value, align 4, !dbg !319
  %5 = bitcast i8* %2 to { float, i8*, %swift.refcounted* } (float, %swift.refcounted*)*, !dbg !319
  %6 = bitcast %swift.opaque* %3 to %swift.refcounted*, !dbg !319
  %7 = call swiftcc { float, i8*, %swift.refcounted* } %5(float %4, %swift.refcounted* swiftself %6) #17, !dbg !319
  %8 = extractvalue { float, i8*, %swift.refcounted* } %7, 0, !dbg !319
  %9 = extractvalue { float, i8*, %swift.refcounted* } %7, 1, !dbg !319
  %10 = extractvalue { float, i8*, %swift.refcounted* } %7, 2, !dbg !319
  %._value1 = getelementptr inbounds %TSf, %TSf* %0, i32 0, i32 0, !dbg !319
  store float %8, float* %._value1, align 4, !dbg !319
  %11 = call noalias %swift.refcounted* @swift_allocObject(%swift.type* getelementptr inbounds (%swift.full_boxmetadata, %swift.full_boxmetadata* @metadata.50, i32 0, i32 2), i64 32, i64 7) #5, !dbg !319
  %12 = bitcast %swift.refcounted* %11 to <{ %swift.refcounted, %swift.function }>*, !dbg !319
  %13 = getelementptr inbounds <{ %swift.refcounted, %swift.function }>, <{ %swift.refcounted, %swift.function }>* %12, i32 0, i32 1, !dbg !319
  %.fn = getelementptr inbounds %swift.function, %swift.function* %13, i32 0, i32 0, !dbg !319
  store i8* %9, i8** %.fn, align 8, !dbg !319
  %.data = getelementptr inbounds %swift.function, %swift.function* %13, i32 0, i32 1, !dbg !319
  store %swift.refcounted* %10, %swift.refcounted** %.data, align 8, !dbg !319
  %14 = insertvalue { i8*, %swift.refcounted* } { i8* bitcast (void (%TSf*, %TSf*, %swift.refcounted*)* @"$sS2fIexyd_S2fIegnr_TRTA" to i8*), %swift.refcounted* undef }, %swift.refcounted* %11, 1, !dbg !319
  ret { i8*, %swift.refcounted* } %14, !dbg !319
}

So a fresh context object (%11) is created here, but the resulting value is never consumed or released.
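
To make the ownership question concrete, here is a minimal Swift sketch of the situation, under the assumption that the leak is an unbalanced +1 on the thunk's context box. It is a model, not compiler code: `Counter`, `ThunkContext`, `makeContext`, `callAndRelease` and `callAndLeak` are hypothetical names, and `Unmanaged` is used only to make the retain/release visible.

```swift
// Hypothetical model of the thunk's context box; none of this is compiler code.

// Tracks how many context boxes are currently alive.
final class Counter { var live = 0 }

// Stands in for the heap box (%11) that stores the captured @callee_owned closure.
final class ThunkContext {
    let inner: (Float) -> Float
    let counter: Counter
    init(inner: @escaping (Float) -> Float, counter: Counter) {
        self.inner = inner
        self.counter = counter
        counter.live += 1
    }
    deinit { counter.live -= 1 }
}

// Like the thunk: allocates the box and hands it to the caller at +1.
func makeContext(_ counter: Counter, _ inner: @escaping (Float) -> Float) -> Unmanaged<ThunkContext> {
    return Unmanaged.passRetained(ThunkContext(inner: inner, counter: counter))
}

// A consumer that balances the +1 after calling through the box.
func callAndRelease(_ context: Unmanaged<ThunkContext>, _ x: Float) -> Float {
    let result = context.takeUnretainedValue().inner(x)
    context.release()
    return result
}

// A consumer that calls through the box but never releases it.
func callAndLeak(_ context: Unmanaged<ThunkContext>, _ x: Float) -> Float {
    return context.takeUnretainedValue().inner(x)
}
```

With a shared `Counter`, calling `callAndRelease` leaves `live` at 0, while `callAndLeak` leaves it at 1. The second case is what the IR above appears to do, assuming nothing downstream releases the box.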

Tagging @gottesmm for some guidance as the reabstraction thunk code is not autodiff-specific here :)

@asl
Contributor

asl commented Jul 14, 2022

@compnerd @gottesmm Do you happen to have any ideas about what might be wrong with the reabstraction thunks in #34935 (comment)? Thanks!

@asl
Contributor

asl commented Feb 9, 2023

@rxwei I wanted to resurrect this. Could you remind me why switching from structs to tuples is crucial here?

@rxwei
Contributor Author

rxwei commented Feb 9, 2023

AST function types don't support the @callee_owned attribute. They are always implicitly @callee_guaranteed. With structs, in this PR I made a hack to bitcast those closures to @callee_guaranteed in order to compute their AST types. Things may have gone wrong there.

After switching to tuples, you can use @callee_owned closures directly as tuple elements.

@asl
Contributor

asl commented Feb 9, 2023

Well, I'm afraid we'd still need to use it, as tuples also use AST function types. Otherwise various optimization passes (e.g. the specializer) assert: they cannot, for example, specialize lowered function types...

@rxwei
Contributor Author

rxwei commented Feb 9, 2023

> Well, I'm afraid we'd still need to use it, as tuples also use AST function types. Otherwise various optimization passes (e.g. the specializer) assert: they cannot, for example, specialize lowered function types...

Now that I try to remember, what I really meant by "tuple" was for linear map struct elements to be added directly to the VJP function's result list and to the pullback function's parameter list, i.e. no nested tuples. That way I don't think there should be any attempt to reconstruct AST types.
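
The difference between the two layouts can be sketched in plain Swift. This is only an illustration of the shape of the signatures, for a hypothetical function f(x) = 3 * x * x built from two callees; all names (`squareVJP`, `tripleVJP`, `LinearMaps`, `fVJPStruct`, `fVJPFlat`, and so on) are made up, and ordinary Swift closures cannot express the @callee_owned convention itself.

```swift
// Hand-written VJPs for two callees, standing in for two `apply` instructions.
func squareVJP(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
    return (x * x, { v in 2 * x * v })
}
func tripleVJP(_ y: Float) -> (value: Float, pullback: (Float) -> Float) {
    return (3 * y, { v in 3 * v })
}

// Layout A: child pullbacks are packed into a nominal struct (hypothetical name).
struct LinearMaps {
    var pb0: (Float) -> Float
    var pb1: (Float) -> Float
}
func fVJPStruct(_ x: Float) -> (Float, LinearMaps) {
    let (y, pb0) = squareVJP(x)
    let (z, pb1) = tripleVJP(y)
    return (z, LinearMaps(pb0: pb0, pb1: pb1))
}
func fPullbackStruct(_ v: Float, _ maps: LinearMaps) -> Float {
    return maps.pb0(maps.pb1(v))
}

// Layout B: child pullbacks are appended directly to the VJP's result list
// and to the pullback's parameter list, with no wrapper type in between.
func fVJPFlat(_ x: Float) -> (Float, (Float) -> Float, (Float) -> Float) {
    let (y, pb0) = squareVJP(x)
    let (z, pb1) = tripleVJP(y)
    return (z, pb0, pb1)
}
func fPullbackFlat(_ v: Float, _ pb0: (Float) -> Float, _ pb1: (Float) -> Float) -> Float {
    return pb0(pb1(v))
}
```

Both layouts compute the same derivative. The point of layout B is that no aggregate type has to be formed for the closures, which is where the AST-type reconstruction came in.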

@asl
Contributor

asl commented Feb 9, 2023

Ah, so we'd essentially unwrap the current tuples, right?

@rxwei
Contributor Author

rxwei commented Feb 9, 2023

Yes. It'd be good to confirm the max parameter limit before committing to this, since our tuple size basically grows linearly with respect to the number of instructions (more specifically, apply)...

@asl
Contributor

asl commented Feb 9, 2023

> Yes. It'd be good to confirm the max parameter limit before committing to this, since our tuple size basically grows linearly with respect to the number of instructions (more specifically, apply)...

Right. However, in very complex cases we can fall back to the wrapped tuple implementation. Essentially, we know the size of the tuple, so we can decide whether we want to unwrap it or not.

@rxwei
Contributor Author

rxwei commented Feb 10, 2023

Tuples have a limit too. I forgot the number

@asl
Contributor

asl commented Feb 10, 2023

> Tuples have a limit too. I forgot the number

32 bits are used to store the number of elements, so I think we're OK here :) The number of function parameters is stored in 32 bits as well. Results use 16 bits for the count.
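
The fallback idea can be sketched as a simple size check. The snippet assumes the 16-bit result count quoted above is the binding limit for a flattened VJP; `Limits`, `LinearMapLayout` and `chooseLayout` are hypothetical names, and a real implementation would probably use a much lower threshold.

```swift
// Field width as reported in this thread; treated as an assumption, not verified here.
enum Limits {
    static let resultCountBits = 16
    static let maxResults = (1 << resultCountBits) - 1
}

// Hypothetical names for the two ways of returning linear maps from a VJP.
enum LinearMapLayout {
    case flattened      // linear maps appended directly to the result list
    case wrappedTuple   // linear maps packed into a single tuple result
}

// Flatten only if the original results plus the linear maps fit in the result count.
func chooseLayout(linearMapCount: Int, originalResultCount: Int = 1) -> LinearMapLayout {
    let total = linearMapCount + originalResultCount
    return total <= Limits.maxResults ? .flattened : .wrappedTuple
}
```

Because the number of linear maps is known when the VJP is generated, a check of this kind could run once per function.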

@asl
Contributor

asl commented Feb 25, 2023

So, I'm afraid things are a little bit more complicated:

  1. We need to use AST types in branch trace enums, so we cannot use @callee_owned there.
  2. We need to use AST types in various other places, mostly connected with generics (e.g. the specializer can only work on AST types, not SIL types).
  3. We cannot simply bitcast, as the lowered pullback type might not be ABI-compatible with the actual type, so proper reabstraction is necessary.
  4. There are also various checks here and there that demand a lowered abstract type (mostly around witness methods and class methods).

Also, even if we introduce reabstraction, the leaks do not go away, so I'm afraid the @callee_owned => @callee_guaranteed reabstraction is a bit buggy at the LLVM IR emission level...

@marcrasi marcrasi removed their request for review June 6, 2023 21:07
@asl asl added the AutoDiff label Aug 23, 2023

5 participants