Skip to content
This repository was archived by the owner on Jan 10, 2023. It is now read-only.

Commit 3fd4aae

Browse files
authored
Merge pull request swiftlang#34898 from gottesmm/pr-d3b6d903b097410b535d20940005a1fa67d3a1bf
[auto-diff] Fix a bunch of places in the *Cloners where we were not closing borrow scopes.
2 parents eb8af67 + affeba9 commit 3fd4aae

File tree

4 files changed

+37
-27
lines changed

4 files changed

+37
-27
lines changed

lib/SIL/IR/SILBuilder.cpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -637,15 +637,18 @@ DebugValueAddrInst *SILBuilder::createDebugValueAddr(SILLocation Loc,
637637

638638
void SILBuilder::emitScopedBorrowOperation(SILLocation loc, SILValue original,
639639
function_ref<void(SILValue)> &&fun) {
640-
if (original->getType().isAddress()) {
641-
original = createLoadBorrow(loc, original);
640+
SILValue value = original;
641+
if (value->getType().isAddress()) {
642+
value = createLoadBorrow(loc, value);
642643
} else {
643-
original = createBeginBorrow(loc, original);
644+
value = emitBeginBorrowOperation(loc, value);
644645
}
645646

646-
fun(original);
647+
fun(value);
647648

648-
createEndBorrow(loc, original);
649+
// If we actually inserted a borrowing operation... insert the end_borrow.
650+
if (value != original)
651+
createEndBorrow(loc, value);
649652
}
650653

651654
CheckedCastBranchInst *SILBuilder::createCheckedCastBranch(

lib/SILOptimizer/Differentiation/JVPCloner.cpp

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -519,11 +519,13 @@ class JVPCloner::Implementation final
519519
return;
520520
}
521521
}
522-
auto borrowedDiffFunc = builder.emitBeginBorrowOperation(loc, origCallee);
523-
jvpValue = builder.createDifferentiableFunctionExtract(
524-
loc, NormalDifferentiableFunctionTypeComponent::JVP,
525-
borrowedDiffFunc);
526-
jvpValue = builder.emitCopyValueOperation(loc, jvpValue);
522+
builder.emitScopedBorrowOperation(
523+
loc, origCallee, [&](SILValue borrowedDiffFunc) {
524+
jvpValue = builder.createDifferentiableFunctionExtract(
525+
loc, NormalDifferentiableFunctionTypeComponent::JVP,
526+
borrowedDiffFunc);
527+
jvpValue = builder.emitCopyValueOperation(loc, jvpValue);
528+
});
527529
}
528530

529531
// If JVP has not yet been found, emit an `differentiable_function`
@@ -614,11 +616,13 @@ class JVPCloner::Implementation final
614616
// Record the `differentiable_function` instruction.
615617
context.getDifferentiableFunctionInstWorklist().push_back(diffFuncInst);
616618

617-
auto borrowedADFunc = builder.emitBeginBorrowOperation(loc, diffFuncInst);
618-
auto extractedJVP = builder.createDifferentiableFunctionExtract(
619-
loc, NormalDifferentiableFunctionTypeComponent::JVP, borrowedADFunc);
620-
jvpValue = builder.emitCopyValueOperation(loc, extractedJVP);
621-
builder.emitEndBorrowOperation(loc, borrowedADFunc);
619+
builder.emitScopedBorrowOperation(
620+
loc, diffFuncInst, [&](SILValue borrowedADFunc) {
621+
auto extractedJVP = builder.createDifferentiableFunctionExtract(
622+
loc, NormalDifferentiableFunctionTypeComponent::JVP,
623+
borrowedADFunc);
624+
jvpValue = builder.emitCopyValueOperation(loc, extractedJVP);
625+
});
622626
builder.emitDestroyValueOperation(loc, diffFuncInst);
623627
}
624628

lib/SILOptimizer/Differentiation/VJPCloner.cpp

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -495,11 +495,13 @@ class VJPCloner::Implementation final
495495
loc, origCallee, SILType::getPrimitiveObjectType(origFnUnsubstType),
496496
/*withoutActuallyEscaping*/ false);
497497
}
498-
auto borrowedDiffFunc = builder.emitBeginBorrowOperation(loc, origCallee);
499-
vjpValue = builder.createDifferentiableFunctionExtract(
500-
loc, NormalDifferentiableFunctionTypeComponent::VJP,
501-
borrowedDiffFunc);
502-
vjpValue = builder.emitCopyValueOperation(loc, vjpValue);
498+
builder.emitScopedBorrowOperation(
499+
loc, origCallee, [&](SILValue borrowedDiffFunc) {
500+
vjpValue = builder.createDifferentiableFunctionExtract(
501+
loc, NormalDifferentiableFunctionTypeComponent::VJP,
502+
borrowedDiffFunc);
503+
vjpValue = builder.emitCopyValueOperation(loc, vjpValue);
504+
});
503505
auto vjpFnType = vjpValue->getType().castTo<SILFunctionType>();
504506
auto vjpFnUnsubstType = vjpFnType->getUnsubstitutedType(getModule());
505507
if (vjpFnType != vjpFnUnsubstType) {
@@ -601,11 +603,14 @@ class VJPCloner::Implementation final
601603
// Record the `differentiable_function` instruction.
602604
context.getDifferentiableFunctionInstWorklist().push_back(diffFuncInst);
603605

604-
auto borrowedADFunc = builder.emitBeginBorrowOperation(loc, diffFuncInst);
605-
auto extractedVJP = getBuilder().createDifferentiableFunctionExtract(
606-
loc, NormalDifferentiableFunctionTypeComponent::VJP, borrowedADFunc);
607-
vjpValue = builder.emitCopyValueOperation(loc, extractedVJP);
608-
builder.emitEndBorrowOperation(loc, borrowedADFunc);
606+
builder.emitScopedBorrowOperation(
607+
loc, diffFuncInst, [&](SILValue borrowedADFunc) {
608+
auto extractedVJP =
609+
getBuilder().createDifferentiableFunctionExtract(
610+
loc, NormalDifferentiableFunctionTypeComponent::VJP,
611+
borrowedADFunc);
612+
vjpValue = builder.emitCopyValueOperation(loc, extractedVJP);
613+
});
609614
builder.emitDestroyValueOperation(loc, diffFuncInst);
610615
}
611616

test/AutoDiff/SILOptimizer/derivative_sil.swift

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ func foo(_ x: Float) -> Float {
3333
// CHECK-SIL: [[ADD_VJP_REF:%.*]] = differentiability_witness_function [vjp] [parameters 0 1] [results 0] @add
3434
// CHECK-SIL: [[ADD_DIFF_FN:%.*]] = differentiable_function [parameters 0 1] [results 0] [[ADD_ORIG_REF]] : $@convention(method) (Float, Float, @thin Float.Type) -> Float with_derivative {[[ADD_JVP_REF]] : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float), [[ADD_VJP_REF]] : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float) -> (Float, Float))}
3535
// CHECK-SIL: [[ADD_JVP_FN:%.*]] = differentiable_function_extract [jvp] [[ADD_DIFF_FN]]
36-
// CHECK-SIL: end_borrow [[ADD_DIFF_FN]]
3736
// CHECK-SIL: [[ADD_RESULT:%.*]] = apply [[ADD_JVP_FN]]([[X]], [[X]], {{.*}})
3837
// CHECK-SIL: ([[ORIG_RES:%.*]], [[ADD_DF:%.*]]) = destructure_tuple [[ADD_RESULT]]
3938
// CHECK-SIL: [[DF_STRUCT:%.*]] = struct $_AD__foo_bb0__DF__src_0_wrt_0 ([[ADD_DF]] : $@callee_guaranteed (Float, Float) -> Float)
@@ -58,7 +57,6 @@ func foo(_ x: Float) -> Float {
5857
// CHECK-SIL: [[ADD_VJP_REF:%.*]] = differentiability_witness_function [vjp] [parameters 0 1] [results 0] @add
5958
// CHECK-SIL: [[ADD_DIFF_FN:%.*]] = differentiable_function [parameters 0 1] [results 0] [[ADD_ORIG_REF]] : $@convention(method) (Float, Float, @thin Float.Type) -> Float with_derivative {[[ADD_JVP_REF]] : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float), [[ADD_VJP_REF]] : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float) -> (Float, Float))}
6059
// CHECK-SIL: [[ADD_VJP_FN:%.*]] = differentiable_function_extract [vjp] [[ADD_DIFF_FN]]
61-
// CHECK-SIL: end_borrow [[ADD_DIFF_FN]]
6260
// CHECK-SIL: [[ADD_RESULT:%.*]] = apply [[ADD_VJP_FN]]([[X]], [[X]], {{.*}})
6361
// CHECK-SIL: ([[ORIG_RES:%.*]], [[ADD_PB:%.*]]) = destructure_tuple [[ADD_RESULT]]
6462
// CHECK-SIL: [[PB_STRUCT:%.*]] = struct $_AD__foo_bb0__PB__src_0_wrt_0 ([[ADD_PB]] : $@callee_guaranteed (Float) -> (Float, Float))

0 commit comments

Comments
 (0)