Skip to content

Commit 3b44a60

Browse files
committed
[AutoDiff] Fix a 'partial_apply' leak caused by subset parameters thunks.
In `ADContext::promoteToDifferentiableFunction`, when we emit a subset parameters thunk, the closure produced by `emitAssociatedFunctionReference` is neither used nor released. This patch fixes the leak by releasing the closure. In a future design, we should determine the actual parameter indices and whether thunking is needed before emitting a closure that may go unused; this will be addressed in a future patch. A value leak checking test was added for protocols because this bug was originally discovered in eaplatanios/swift-ale#1 by using the newly split `Layer` protocol. However, value leak checking tests do not detect closure (`partial_apply`) leaks. The proper setup for catching closure leaks is to move AD before ownership stripping (swiftlang#26157).
1 parent b2720a3 commit 3b44a60

File tree

3 files changed

+76
-3
lines changed

3 files changed

+76
-3
lines changed

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6607,7 +6607,7 @@ SILValue ADContext::promoteToDifferentiableFunction(
66076607
getGeneratedAssociatedFunctionReferences().push_back(assocFn);
66086608

66096609
// If desired indices are a subset of actual indices, create a "subset
6610-
// indices thunk".
6610+
// indices thunk" and destroy the emitted associated function reference.
66116611
// - For JVPs: the thunked JVP returns a differential taking fewer
66126612
// parameters (using `.zero` for the dropped parameters).
66136613
// - For VJPs: the thunked VJP returns a pullback that drops the unused
@@ -6621,6 +6621,9 @@ SILValue ADContext::promoteToDifferentiableFunction(
66216621
getASTContext(), actualIndices.parameters->getCapacity());
66226622
if (actualIndices.source != desiredIndices.source ||
66236623
!actualIndices.parameters->equals(extendedDesiredIndices)) {
6624+
// Destroy the already emitted associated function reference because it
6625+
// is no longer used.
6626+
builder.emitReleaseValueAndFold(loc, assocFn);
66246627
// Check if underlying original function reference has been partially
66256628
// applied with arguments. If so, produce an error: parameter subset
66266629
// thunks do not yet support this case because partially applied arguments

test/AutoDiff/leakchecking.swift

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,10 +74,14 @@ extension DummyLayer {
7474
func defaultImpl(_ input: Input) -> Output {
7575
return requirement(input)
7676
}
77-
func vjpDefaultImpl(_ input: Input) -> (Output, (Self.Output.TangentVector) -> (Self.TangentVector, Self.Input.TangentVector)) {
78-
return Swift.valueWithPullback(at: self, input) { (m, i) in m.requirement(i) }
77+
func vjpDefaultImpl(_ input: Input)
78+
-> (Output,
79+
(Self.Output.TangentVector)
80+
-> (Self.TangentVector, Self.Input.TangentVector)) {
81+
return Swift.valueWithPullback(at: self, input) { $0.requirement($1) }
7982
}
8083
}
84+
8185
LeakCheckingTests.testWithLeakChecking("TestProtocolDefaultDerivative") {
8286
struct Foo : DummyLayer {
8387
// NOTE: Make sure not to override `defaultImpl`.
@@ -98,6 +102,40 @@ LeakCheckingTests.testWithLeakChecking("TestProtocolDefaultDerivative") {
98102
}
99103
}
100104

105+
protocol Module : Differentiable {
106+
associatedtype Input
107+
associatedtype Output : Differentiable
108+
@differentiable(wrt: self)
109+
func callAsFunction(_ input: Input) -> Output
110+
}
111+
protocol Layer : Module where Input : Differentiable {
112+
@differentiable(wrt: (self, input))
113+
func callAsFunction(_ input: Input) -> Output
114+
}
115+
116+
LeakCheckingTests.testWithLeakChecking("ProtocolRequirements") {
117+
struct Dense: Layer {
118+
var w = Tracked<Float>(1)
119+
@differentiable
120+
func callAsFunction(_ input: Tracked<Float>) -> Tracked<Float> {
121+
input * w
122+
}
123+
}
124+
struct Model: Module {
125+
var dense1 = Dense()
126+
var dense2 = Dense()
127+
@differentiable
128+
func callAsFunction(_ input: Tracked<Int>) -> Tracked<Float> {
129+
dense2(dense1(Tracked(Float(input.value))))
130+
}
131+
}
132+
let x = Tracked<Int>(1)
133+
let model = Model()
134+
_ = model.valueWithGradient { model in
135+
model(x)
136+
}
137+
}
138+
101139
LeakCheckingTests.testWithLeakChecking("LetStructs") {
102140
func structConstructionWithOwnedParams(_ x: Tracked<Float>) -> Tracked<Float> {
103141
let z = Tracked(x)
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
// RUN: %target-swift-frontend -emit-sil %s | %FileCheck %s
2+
3+
@differentiable(where T: Differentiable)
4+
func foo<T: Numeric>(_ x: T, _ y: T) -> T { x * y }
5+
6+
@differentiating(foo)
7+
func foo_vjp<T: Numeric & Differentiable>(_ x: T, _ y: T) -> (value: T, pullback: (T.TangentVector) -> (T.TangentVector, T.TangentVector)) {
8+
(foo(x, y), { _ in (.zero, .zero) })
9+
}
10+
11+
let x = Float(1)
12+
@differentiable
13+
func differentiate_foo_wrt_0(_ x: Float) -> Float {
14+
foo(x, 1)
15+
}
16+
17+
// CHECK-LABEL: @{{.*}}differentiate_foo_wrt_0{{.*}}__vjp
18+
// CHECK: bb0
19+
// CHECK: [[FOO_ORIG:%.*]] = function_ref @{{.*}}foo{{.*}} : $@convention(thin) <τ_0_0 where τ_0_0 : Numeric> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0) -> @out τ_0_0
20+
// CHECK: [[FOO_FLOAT:%.*]] = partial_apply [callee_guaranteed] [[FOO_ORIG]]<Float>() : $@convention(thin) <τ_0_0 where τ_0_0 : Numeric> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0) -> @out τ_0_0
21+
// CHECK: [[FOO_JVP:%.*]] = function_ref @AD__{{.*}}foo{{.*}}__jvp_src_0_wrt_0_1 : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable, τ_0_0 : Numeric> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, @in_guaranteed τ_0_0.TangentVector) -> @out τ_0_0.TangentVector)
22+
// CHECK: [[FOO_JVP_FLOAT:%.*]] = partial_apply [callee_guaranteed] [[FOO_JVP]]<Float>() : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable, τ_0_0 : Numeric> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, @in_guaranteed τ_0_0.TangentVector) -> @out τ_0_0.TangentVector)
23+
// CHECK: release_value [[FOO_JVP_FLOAT]]
24+
// CHECK: [[FOO_JVP_SUBSET_THUNK_THIN:%.*]] = function_ref @AD__orig_{{.*}}foo{{.*}}_src_0_wrt_0_jvp_subset_parameters_thunk : $@convention(thin) (@in_guaranteed Float, @in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float)
25+
// CHECK: [[FOO_JVP_SUBSET_THUNK:%.*]] = thin_to_thick_function [[FOO_JVP_SUBSET_THUNK_THIN]] : $@convention(thin) (@in_guaranteed Float, @in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float) to $@callee_guaranteed (@in_guaranteed Float, @in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float)
26+
// CHECK: [[FOO_VJP:%.*]] = function_ref @{{.*}}foo_vjp{{.*}} : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable, τ_0_0 : Numeric> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, @out τ_0_0.TangentVector))
27+
// CHECK: [[FOO_VJP_FLOAT:%.*]] = partial_apply [callee_guaranteed] [[FOO_VJP]]<Float>() : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable, τ_0_0 : Numeric> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, @out τ_0_0.TangentVector))
28+
// CHECK: release_value [[FOO_VJP_FLOAT]]
29+
// CHECK: [[FOO_VJP_SUBSET_THUNK_THIN:%.*]] = function_ref @AD__orig_{{.*}}foo{{.*}}_src_0_wrt_0_vjp_subset_parameters_thunk : $@convention(thin) (@in_guaranteed Float, @in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float)
30+
// CHECK: [[FOO_VJP_SUBSET_THUNK:%.*]] = thin_to_thick_function [[FOO_VJP_SUBSET_THUNK_THIN]] : $@convention(thin) (@in_guaranteed Float, @in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float) to $@callee_guaranteed (@in_guaranteed Float, @in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float)
31+
// CHECK: [[FOO_DIFF:%.*]] = autodiff_function [wrt 0] [order 1] [[FOO_FLOAT]] : $@callee_guaranteed (@in_guaranteed Float, @in_guaranteed Float) -> @out Float with {[[FOO_JVP_SUBSET_THUNK]] : $@callee_guaranteed (@in_guaranteed Float, @in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float), [[FOO_VJP_SUBSET_THUNK]] : $@callee_guaranteed (@in_guaranteed Float, @in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float)}
32+
// CHECK: }

0 commit comments

Comments
 (0)