Skip to content

[AutoDiff] fix ownership instructions (SR-13973) #35196

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions lib/SILOptimizer/Mandatory/Differentiation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,8 @@ static void copyParameterArgumentsForApply(
// Objects are to be retained.
if (arg->getType().isObject()) {
auto newArg = arg;
if (newArg.getOwnershipKind() != OwnershipKind::None)
if (!copyBuilder.hasOwnership() ||
newArg.getOwnershipKind() != OwnershipKind::None)
newArg = copyBuilder.emitCopyValueOperation(loc, arg);
collectNewArg(newArg);
continue;
Expand Down Expand Up @@ -500,7 +501,8 @@ emitDerivativeFunctionReference(
builder.emitBeginBorrowOperation(original.getLoc(), original);
SILValue derivativeFn = builder.createDifferentiableFunctionExtract(
borrowedDiffFunc.getLoc(), kind, borrowedDiffFunc);
if (derivativeFn.getOwnershipKind() != OwnershipKind::None)
if (!builder.hasOwnership() ||
derivativeFn.getOwnershipKind() != OwnershipKind::None)
derivativeFn =
builder.emitCopyValueOperation(original.getLoc(), derivativeFn);
builder.emitEndBorrowOperation(original.getLoc(), borrowedDiffFunc);
Expand Down Expand Up @@ -867,7 +869,8 @@ static void emitFatalError(ADContext &context, SILFunction *f,
auto loc = f->getLocation();
// Destroy all owned arguments to pass ownership verification.
for (auto *arg : entry->getArguments())
if (arg->getOwnershipKind() == OwnershipKind::Owned)
if (!builder.hasOwnership() ||
arg->getOwnershipKind() == OwnershipKind::Owned)
builder.emitDestroyOperation(loc, arg);
// Fatal error with a nice message.
auto neverResultInfo =
Expand Down Expand Up @@ -1213,7 +1216,8 @@ SILValue DifferentiationTransformer::promoteToDifferentiableFunction(
builder.createDeallocStack(loc, buf);

// If our original copy does not have none ownership, copy it.
if (origFnOperand.getOwnershipKind() != OwnershipKind::None)
if (!builder.hasOwnership() ||
origFnOperand.getOwnershipKind() != OwnershipKind::None)
origFnOperand = builder.emitCopyValueOperation(loc, origFnOperand);
auto *newDiffFn = context.createDifferentiableFunction(
builder, loc, parameterIndices, resultIndices, origFnOperand,
Expand All @@ -1229,7 +1233,8 @@ SILValue DifferentiationTransformer::promoteToLinearFunction(
// with an undef transpose function operand. Eventually, a legitimate
// transpose function operand should be created and used.
auto origFnOperand = lfi->getOriginalFunction();
if (origFnOperand.getOwnershipKind() != OwnershipKind::None)
if (!builder.hasOwnership() ||
origFnOperand.getOwnershipKind() != OwnershipKind::None)
origFnOperand = builder.emitCopyValueOperation(loc, origFnOperand);
auto *parameterIndices = lfi->getParameterIndices();
auto originalType = origFnOperand->getType().castTo<SILFunctionType>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,3 +139,50 @@ bb0(%0 : $Class):
// CHECK: [[VJP_FN_PARTIALLY_APPLIED:%.*]] = partial_apply [callee_guaranteed] [[VJP_FN]]([[ARG]])
// CHECK: differentiable_function [parameters 0] [results 0] [[ORIG_FN_PARTIALLY_APPLIED]] : {{.*}} with_derivative {[[JVP_FN_PARTIALLY_APPLIED]] : {{.*}}, [[VJP_FN_PARTIALLY_APPLIED]] : {{.*}}}
// CHECK: }

// Test curry thunks.

struct S {
var x: Float
}

sil @method : $@convention(method) (Float, S) -> Float {
bb0(%0 : $Float, %1 : $S):
return %0 : $Float
}

sil @method_thin : $@convention(thin) (Float, S) -> Float {
bb0(%0 : $Float, %1 : $S):
%4 = function_ref @method : $@convention(method) (Float, S) -> Float
%5 = apply %4(%0, %1) : $@convention(method) (Float, S) -> Float
return %5 : $Float
}

sil @method_curried : $@convention(thin) (S) -> @owned @callee_guaranteed (Float) -> Float {
bb0(%0 : $S):
%2 = function_ref @method_thin : $@convention(thin) (Float, S) -> Float
%3 = partial_apply [callee_guaranteed] %2(%0) : $@convention(thin) (Float, S) -> Float
return %3 : $@callee_guaranteed (Float) -> Float
}

sil @test_curry_thunk : $@convention(thin) (Float, S) -> () {
bb0(%0 : $Float, %1 : $S):
%2 = function_ref @method_curried : $@convention(thin) (S) -> @owned @callee_guaranteed (Float) -> Float
%3 = apply %2(%1) : $@convention(thin) (S) -> @owned @callee_guaranteed (Float) -> Float
%4 = differentiable_function [parameters 0] [results 0] %3 : $@callee_guaranteed (Float) -> Float
%5 = tuple ()
return %5 : $()
}

// CHECK-LABEL: sil {{.*}} @AD__method_curried__differentiable_curry_thunk_src_0_wrt_0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a ':' at the end of this to prevent prefix double matches?

Can you also add an } // end sil function '$NAME' to the end.

This is just basic SIL test hygiene.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

// CHECK: bb0([[SELF:%.*]] : $S):
// CHECK: [[METHOD:%.*]] = function_ref @method_thin : $@convention(thin) (Float, S) -> Float
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add an ossa test? Doesn't this transform only run on OSSA? I guess not?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does that mean add a test where the sil functions say [ossa]? Done.

I guess this transform doesn't only run on OSSA...

// CHECK: [[CURRIED:%.*]] = partial_apply [callee_guaranteed] [[METHOD]]([[SELF]]) : $@convention(thin) (Float, S) -> Float
// CHECK: [[METHOD_JVP:%.*]] = differentiability_witness_function [jvp] [parameters 0] [results 0] @method_thin : $@convention(thin) (Float, S) -> Float
// CHECK: [[CURRIED_JVP:%.*]] = partial_apply [callee_guaranteed] [[METHOD_JVP]]([[SELF]]) : $@convention(thin) (Float, S) -> (Float, @owned @callee_guaranteed (Float) -> Float)
// CHECK: [[METHOD_VJP:%.*]] = differentiability_witness_function [vjp] [parameters 0] [results 0] @method_thin : $@convention(thin) (Float, S) -> Float
// CHECK: [[CURRIED_VJP:%.*]] = partial_apply [callee_guaranteed] [[METHOD_VJP]]([[SELF]]) : $@convention(thin) (Float, S) -> (Float, @owned @callee_guaranteed (Float) -> Float)
// CHECK: strong_retain [[CURRIED]] : $@callee_guaranteed (Float) -> Float
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This strong_retain gets added by this PR.

// CHECK: [[RESULT:%.*]] = differentiable_function [parameters 0] [results 0] [[CURRIED]] : $@callee_guaranteed (Float) -> Float with_derivative {[[CURRIED_JVP]] : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float), [[CURRIED_VJP]] : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)}
// CHECK: strong_release [[CURRIED]] : $@callee_guaranteed (Float) -> Float
// CHECK: return [[RESULT]] : $@differentiable @callee_guaranteed (Float) -> Float
24 changes: 24 additions & 0 deletions test/AutoDiff/validation-test/address_sanitizer.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// RUN: %empty-directory(%t)
// RUN: %target-build-swift %s -sanitize=address -o %t/sr13973
// RUN: %target-run %t/sr13973

// REQUIRES: executable_test
// REQUIRES: asan_runtime

import _Differentiation

struct SR13973 {
let x: Float = 0

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can I have a SIL test?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For changes like this, both are important!

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added a FileCheck to test/AutoDiff/SILOptimizer/differentiation_function_canonicalization.sil. It's pretty much the same thing that I do in this file, except in SIL.

@differentiable
func errorVector(_ t: Float) -> Float {
return t
}
}

func sr13973() {
let s = SR13973()
_ = valueWithPullback(at: 0, in: s.errorVector)
}

sr13973()