Skip to content

Commit 6773a69

Browse files
authored
[AutoDiff] Handle @inout_aliasable when converting associated functions (#26484)
When reapplying a `partial_apply` for a JVP/VJP, if the argument has convention `@inout_aliasable`, it is a noescape mutable capture and the underlying value should not be retained. This fixes a memory leak when a noescape closure being differentiated captures a mutable self. Resolves eaplatanios/swift-ale#1.
1 parent 9c95e27 commit 6773a69

File tree

3 files changed

+63
-8
lines changed

3 files changed

+63
-8
lines changed

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1947,21 +1947,40 @@ reapplyFunctionConversion(SILValue newFunc, SILValue oldFunc,
19471947
for (auto *alloc : reversed(copiedIndirectParams))
19481948
builder.createDeallocStack(loc, alloc);
19491949
};
1950-
for (auto arg : pai->getArguments()) {
1951-
// Retain the argument since it's to be owned by the newly created
1950+
// Collect new arguments to for a new `partial_apply`.
1951+
auto conv = pai->getSubstCalleeConv();
1952+
unsigned argIndex = conv.getNumSILArguments() - pai->getNumArguments();
1953+
for (auto argIt = pai->getArguments().begin();
1954+
argIt != pai->getArguments().end(); ++argIt, ++argIndex) {
1955+
auto arg = *argIt;
1956+
// Retain the argument if it's to be owned by the newly created
19521957
// closure.
1958+
// Objects are to be retained.
19531959
if (arg->getType().isObject()) {
19541960
builder.createRetainValue(loc, arg, builder.getDefaultAtomicity());
19551961
newArgs.push_back(arg);
1956-
} else if (arg->getType().isLoadable(builder.getFunction())) {
1962+
continue;
1963+
}
1964+
// Addresses depend on argument conventions.
1965+
// If the argument is an aliasable inout reference, do not retain the
1966+
// argument since it's a `@noescape` capture.
1967+
auto argConv = conv.getSILArgumentConvention(argIndex);
1968+
if (argConv == SILArgumentConvention::Indirect_InoutAliasable) {
1969+
newArgs.push_back(arg);
1970+
continue;
1971+
}
1972+
// If it's a loadable address, perform a `retain_value_addr`.
1973+
if (arg->getType().isLoadable(builder.getFunction())) {
19571974
builder.createRetainValueAddr(loc, arg, builder.getDefaultAtomicity());
19581975
newArgs.push_back(arg);
1959-
} else {
1960-
auto *argCopy = builder.createAllocStack(loc, arg->getType());
1961-
copiedIndirectParams.push_back(argCopy);
1962-
builder.createCopyAddr(loc, arg, argCopy, IsNotTake, IsInitialization);
1963-
newArgs.push_back(argCopy);
1976+
continue;
19641977
}
1978+
// Otherwise, it must be address-only. Create a new buffer and perform
1979+
// `copy_addr`.
1980+
auto *argCopy = builder.createAllocStack(loc, arg->getType());
1981+
copiedIndirectParams.push_back(argCopy);
1982+
builder.createCopyAddr(loc, arg, argCopy, IsNotTake, IsInitialization);
1983+
newArgs.push_back(argCopy);
19651984
}
19661985
auto innerNewFunc = reapplyFunctionConversion(
19671986
newFunc, oldFunc, pai->getCallee(), builder, loc, newFuncGenSig);

test/AutoDiff/closures.swift

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,27 @@ func diffableClosureInStruct(s: Foo) {
1313
// CHECK: retain_value [[CLOSURE]] : $@differentiable @callee_guaranteed (Float) -> Float
1414
// CHECK: autodiff_function_extract [original] [[CLOSURE]] : $@differentiable @callee_guaranteed (Float) -> Float
1515

16+
struct InoutAliasableCapture {
17+
var x: Float = .zero
18+
mutating func foo() {
19+
func capturesMutableSelf(t: Float) -> Float {
20+
self.x = .zero
21+
return t
22+
}
23+
_ = gradient(at: .zero, in: capturesMutableSelf)
24+
}
25+
}
26+
27+
// CHECK-LABEL: @{{.*}}InoutAliasableCapture{{.*}}foo{{.*}} : $@convention(method) (@inout InoutAliasableCapture) -> () {
28+
// CHECK: bb0([[SELF:%.*]] : $*InoutAliasableCapture):
29+
// CHECK: [[JVP:%.*]] = function_ref @{{.*}}capturesMutableSelf{{.*}}__jvp_src_0_wrt_0 : $@convention(thin) (Float, @inout_aliasable InoutAliasableCapture) -> (Float, @owned @callee_guaranteed (Float) -> Float)
30+
// CHECK-NOT: retain_value_addr [[SELF]]
31+
// CHECK-NOT: copy_addr [[SELF]]
32+
// CHECK: [[JVP_CAPTURED:%.*]] = partial_apply [callee_guaranteed] [[JVP]]([[SELF]]) : $@convention(thin) (Float, @inout_aliasable InoutAliasableCapture) -> (Float, @owned @callee_guaranteed (Float) -> Float)
33+
// CHECK: [[VJP:%.*]] = function_ref @{{.*}}capturesMutableSelf{{.*}}__vjp_src_0_wrt_0 : $@convention(thin) (Float, @inout_aliasable InoutAliasableCapture) -> (Float, @owned @callee_guaranteed (Float) -> Float)
34+
// CHECK-NOT: retain_value_addr [[SELF]]
35+
// CHECK-NOT: copy_addr [[SELF]]
36+
// CHECK: [[VJP_CAPTURED:%.*]] = partial_apply [callee_guaranteed] [[VJP]]([[SELF]]) : $@convention(thin) (Float, @inout_aliasable InoutAliasableCapture) -> (Float, @owned @callee_guaranteed (Float) -> Float)
1637

1738
public func closureCaptureMutable() {
1839
var val: Float = 10

test/AutoDiff/leakchecking.swift

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,21 @@ LeakCheckingTests.testWithLeakChecking("ClosureCaptureLeakChecking") {
418418
return model.applied(to: x)
419419
}
420420
}
421+
422+
do {
423+
struct Foo {
424+
let x: Tracked<Float> = .zero
425+
var y: Tracked<Float> = .zero
426+
mutating func differentiateSomethingThatCapturesSelf() {
427+
_ = x.gradient { x in
428+
self.y += .zero
429+
return .zero
430+
}
431+
}
432+
}
433+
var foo = Foo()
434+
foo.differentiateSomethingThatCapturesSelf()
435+
}
421436
}
422437

423438
LeakCheckingTests.testWithLeakChecking("ControlFlowWithTrivialUnconditionalMath") {

0 commit comments

Comments
 (0)