Skip to content

Commit d38f69b

Browse files
committed
[AutoDiff] Handle @inout_aliasable when converting associated functions.
When reapplying a `partial_apply` for a JVP/VJP, if the argument has convention `@inout_aliasable`, it is a noescape mutable capture and the underlying value should not be retained. This fixes a memory leak when a noescape closure being differentiated captures a mutable self.
1 parent b2720a3 commit d38f69b

File tree

3 files changed

+61
-9
lines changed

3 files changed

+61
-9
lines changed

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1951,19 +1951,35 @@ reapplyFunctionConversion(SILValue newFunc, SILValue oldFunc,
19511951
builder.createDeallocStack(loc, alloc);
19521952
};
19531953
for (auto arg : pai->getArguments()) {
1954-
// Retain the argument since it's to be owned by the newly created
1954+
// Retain the argument if it's to be owned by the newly created
19551955
// closure.
1956+
// Objects are to be retained.
19561957
if (arg->getType().isObject()) {
19571958
builder.createRetainValue(loc, arg, builder.getDefaultAtomicity());
19581959
newArgs.push_back(arg);
1959-
} else if (arg->getType().isLoadable(builder.getFunction())) {
1960-
builder.createRetainValueAddr(loc, arg, builder.getDefaultAtomicity());
1961-
newArgs.push_back(arg);
1962-
} else {
1963-
auto *argCopy = builder.createAllocStack(loc, arg->getType());
1964-
copiedIndirectParams.push_back(argCopy);
1965-
builder.createCopyAddr(loc, arg, argCopy, IsNotTake, IsInitialization);
1966-
newArgs.push_back(argCopy);
1960+
}
1961+
// Addresses depend on argument conventions.
1962+
else {
1963+
auto conv = pai->getCalleeFunction()->getConventions();
1964+
auto argConv =
1965+
conv.getSILArgumentConvention(conv.getNumSILArguments() - 1);
1966+
// If the argument is an aliasable inout reference, do not retain the
1967+
// argument since it's a `@noescape` capture.
1968+
if (argConv == SILArgumentConvention::Indirect_InoutAliasable) {
1969+
newArgs.push_back(arg);
1970+
}
1971+
// Otherwise, retain/copy the underlying value.
1972+
else if (arg->getType().isLoadable(builder.getFunction())) {
1973+
builder.createRetainValueAddr(loc, arg,
1974+
builder.getDefaultAtomicity());
1975+
newArgs.push_back(arg);
1976+
} else {
1977+
auto *argCopy = builder.createAllocStack(loc, arg->getType());
1978+
copiedIndirectParams.push_back(argCopy);
1979+
builder.createCopyAddr(loc, arg, argCopy, IsNotTake,
1980+
IsInitialization);
1981+
newArgs.push_back(argCopy);
1982+
}
19671983
}
19681984
}
19691985
auto innerNewFunc = reapplyFunctionConversion(

test/AutoDiff/closures.swift

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,27 @@ func diffableClosureInStruct(s: Foo) {
1313
// CHECK: retain_value [[CLOSURE]] : $@differentiable @callee_guaranteed (Float) -> Float
1414
// CHECK: autodiff_function_extract [original] [[CLOSURE]] : $@differentiable @callee_guaranteed (Float) -> Float
1515

16+
struct InoutAliasableCapture {
17+
var x: Float = .zero
18+
mutating func foo() {
19+
func capturesMutableSelf(t: Float) -> Float {
20+
self.x = .zero
21+
return t
22+
}
23+
_ = gradient(at: .zero, in: capturesMutableSelf)
24+
}
25+
}
26+
27+
// CHECK-LABEL: @{{.*}}InoutAliasableCapture{{.*}}foo{{.*}} : $@convention(method) (@inout InoutAliasableCapture) -> () {
28+
// CHECK: bb0([[SELF:%.*]] : $*InoutAliasableCapture):
29+
// CHECK: [[JVP:%.*]] = function_ref @{{.*}}capturesMutableSelf{{.*}}__jvp_src_0_wrt_0 : $@convention(thin) (Float, @inout_aliasable InoutAliasableCapture) -> (Float, @owned @callee_guaranteed (Float) -> Float)
30+
// CHECK-NOT: retain_value_addr [[SELF]]
31+
// CHECK-NOT: copy_addr [[SELF]]
32+
// CHECK: [[JVP_CAPTURED:%.*]] = partial_apply [callee_guaranteed] [[JVP]]([[SELF]]) : $@convention(thin) (Float, @inout_aliasable InoutAliasableCapture) -> (Float, @owned @callee_guaranteed (Float) -> Float)
33+
// CHECK: [[VJP:%.*]] = function_ref @{{.*}}capturesMutableSelf{{.*}}__vjp_src_0_wrt_0 : $@convention(thin) (Float, @inout_aliasable InoutAliasableCapture) -> (Float, @owned @callee_guaranteed (Float) -> Float)
34+
// CHECK-NOT: retain_value_addr [[SELF]]
35+
// CHECK-NOT: copy_addr [[SELF]]
36+
// CHECK: [[VJP_CAPTURED:%.*]] = partial_apply [callee_guaranteed] [[VJP]]([[SELF]]) : $@convention(thin) (Float, @inout_aliasable InoutAliasableCapture) -> (Float, @owned @callee_guaranteed (Float) -> Float)
1637

1738
public func closureCaptureMutable() {
1839
var val: Float = 10

test/AutoDiff/leakchecking.swift

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,21 @@ LeakCheckingTests.testWithLeakChecking("ClosureCaptureLeakChecking") {
207207
return model.applied(to: x)
208208
}
209209
}
210+
211+
do {
212+
struct Foo {
213+
let x: Tracked<Float> = .zero
214+
var y: Tracked<Float> = .zero
215+
mutating func differentiateSomethingThatCapturesSelf() {
216+
_ = x.gradient { x in
217+
self.y += .zero
218+
return .zero
219+
}
220+
}
221+
}
222+
var foo = Foo()
223+
foo.differentiateSomethingThatCapturesSelf()
224+
}
210225
}
211226

212227
LeakCheckingTests.testWithLeakChecking("ControlFlowWithTrivialUnconditionalMath") {

0 commit comments

Comments
 (0)