Skip to content

[AutoDiff] Handle @inout_aliasable when converting associated functions #26484

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 6, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 27 additions & 8 deletions lib/SILOptimizer/Mandatory/Differentiation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1950,21 +1950,40 @@ reapplyFunctionConversion(SILValue newFunc, SILValue oldFunc,
for (auto *alloc : reversed(copiedIndirectParams))
builder.createDeallocStack(loc, alloc);
};
for (auto arg : pai->getArguments()) {
// Retain the argument since it's to be owned by the newly created
// Collect new arguments to for a new `partial_apply`.
auto conv = pai->getSubstCalleeConv();
unsigned argIndex = conv.getNumSILArguments() - pai->getNumArguments();
for (auto argIt = pai->getArguments().begin();
argIt != pai->getArguments().end(); ++argIt, ++argIndex) {
auto arg = *argIt;
// Retain the argument if it's to be owned by the newly created
// closure.
// Objects are to be retained.
if (arg->getType().isObject()) {
builder.createRetainValue(loc, arg, builder.getDefaultAtomicity());
newArgs.push_back(arg);
} else if (arg->getType().isLoadable(builder.getFunction())) {
continue;
}
// Addresses depend on argument conventions.
// If the argument is an aliasable inout reference, do not retain the
// argument since it's a `@noescape` capture.
auto argConv = conv.getSILArgumentConvention(argIndex);
if (argConv == SILArgumentConvention::Indirect_InoutAliasable) {
newArgs.push_back(arg);
continue;
}
// If it's a loadable address, perform a `retain_value_addr`.
if (arg->getType().isLoadable(builder.getFunction())) {
builder.createRetainValueAddr(loc, arg, builder.getDefaultAtomicity());
newArgs.push_back(arg);
} else {
auto *argCopy = builder.createAllocStack(loc, arg->getType());
copiedIndirectParams.push_back(argCopy);
builder.createCopyAddr(loc, arg, argCopy, IsNotTake, IsInitialization);
newArgs.push_back(argCopy);
continue;
}
// Otherwise, it must be address-only. Create a new buffer and perform
// `copy_addr`.
auto *argCopy = builder.createAllocStack(loc, arg->getType());
copiedIndirectParams.push_back(argCopy);
builder.createCopyAddr(loc, arg, argCopy, IsNotTake, IsInitialization);
newArgs.push_back(argCopy);
}
auto innerNewFunc = reapplyFunctionConversion(
newFunc, oldFunc, pai->getCallee(), builder, loc, newFuncGenSig);
Expand Down
21 changes: 21 additions & 0 deletions test/AutoDiff/closures.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,27 @@ func diffableClosureInStruct(s: Foo) {
// CHECK: retain_value [[CLOSURE]] : $@differentiable @callee_guaranteed (Float) -> Float
// CHECK: autodiff_function_extract [original] [[CLOSURE]] : $@differentiable @callee_guaranteed (Float) -> Float

struct InoutAliasableCapture {
var x: Float = .zero
mutating func foo() {
func capturesMutableSelf(t: Float) -> Float {
self.x = .zero
return t
}
_ = gradient(at: .zero, in: capturesMutableSelf)
}
}

// CHECK-LABEL: @{{.*}}InoutAliasableCapture{{.*}}foo{{.*}} : $@convention(method) (@inout InoutAliasableCapture) -> () {
// CHECK: bb0([[SELF:%.*]] : $*InoutAliasableCapture):
// CHECK: [[JVP:%.*]] = function_ref @{{.*}}capturesMutableSelf{{.*}}__jvp_src_0_wrt_0 : $@convention(thin) (Float, @inout_aliasable InoutAliasableCapture) -> (Float, @owned @callee_guaranteed (Float) -> Float)
// CHECK-NOT: retain_value_addr [[SELF]]
// CHECK-NOT: copy_addr [[SELF]]
// CHECK: [[JVP_CAPTURED:%.*]] = partial_apply [callee_guaranteed] [[JVP]]([[SELF]]) : $@convention(thin) (Float, @inout_aliasable InoutAliasableCapture) -> (Float, @owned @callee_guaranteed (Float) -> Float)
// CHECK: [[VJP:%.*]] = function_ref @{{.*}}capturesMutableSelf{{.*}}__vjp_src_0_wrt_0 : $@convention(thin) (Float, @inout_aliasable InoutAliasableCapture) -> (Float, @owned @callee_guaranteed (Float) -> Float)
// CHECK-NOT: retain_value_addr [[SELF]]
// CHECK-NOT: copy_addr [[SELF]]
// CHECK: [[VJP_CAPTURED:%.*]] = partial_apply [callee_guaranteed] [[VJP]]([[SELF]]) : $@convention(thin) (Float, @inout_aliasable InoutAliasableCapture) -> (Float, @owned @callee_guaranteed (Float) -> Float)

public func closureCaptureMutable() {
var val: Float = 10
Expand Down
15 changes: 15 additions & 0 deletions test/AutoDiff/leakchecking.swift
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,21 @@ LeakCheckingTests.testWithLeakChecking("ClosureCaptureLeakChecking") {
return model.applied(to: x)
}
}

do {
struct Foo {
let x: Tracked<Float> = .zero
var y: Tracked<Float> = .zero
mutating func differentiateSomethingThatCapturesSelf() {
_ = x.gradient { x in
self.y += .zero
return .zero
}
}
}
var foo = Foo()
foo.differentiateSomethingThatCapturesSelf()
}
}

LeakCheckingTests.testWithLeakChecking("ControlFlowWithTrivialUnconditionalMath") {
Expand Down