Skip to content

Commit affeba9

Browse files
committed
[auto-diff] Fix a bunch of places in the *Cloners where we were not closing borrow scopes.
These were all just trying to open a borrow scope, so I changed them to use the API SILBuilder::emitScopedBorrowOperation(SILLocation, SIL).
1 parent 0d728ee commit affeba9

File tree

4 files changed

+37
-27
lines changed

4 files changed

+37
-27
lines changed

lib/SIL/IR/SILBuilder.cpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -637,15 +637,18 @@ DebugValueAddrInst *SILBuilder::createDebugValueAddr(SILLocation Loc,
637637

638638
void SILBuilder::emitScopedBorrowOperation(SILLocation loc, SILValue original,
639639
function_ref<void(SILValue)> &&fun) {
640-
if (original->getType().isAddress()) {
641-
original = createLoadBorrow(loc, original);
640+
SILValue value = original;
641+
if (value->getType().isAddress()) {
642+
value = createLoadBorrow(loc, value);
642643
} else {
643-
original = createBeginBorrow(loc, original);
644+
value = emitBeginBorrowOperation(loc, value);
644645
}
645646

646-
fun(original);
647+
fun(value);
647648

648-
createEndBorrow(loc, original);
649+
// If we actually inserted a borrowing operation... insert the end_borrow.
650+
if (value != original)
651+
createEndBorrow(loc, value);
649652
}
650653

651654
CheckedCastBranchInst *SILBuilder::createCheckedCastBranch(

lib/SILOptimizer/Differentiation/JVPCloner.cpp

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -514,11 +514,13 @@ class JVPCloner::Implementation final
514514
return;
515515
}
516516
}
517-
auto borrowedDiffFunc = builder.emitBeginBorrowOperation(loc, origCallee);
518-
jvpValue = builder.createDifferentiableFunctionExtract(
519-
loc, NormalDifferentiableFunctionTypeComponent::JVP,
520-
borrowedDiffFunc);
521-
jvpValue = builder.emitCopyValueOperation(loc, jvpValue);
517+
builder.emitScopedBorrowOperation(
518+
loc, origCallee, [&](SILValue borrowedDiffFunc) {
519+
jvpValue = builder.createDifferentiableFunctionExtract(
520+
loc, NormalDifferentiableFunctionTypeComponent::JVP,
521+
borrowedDiffFunc);
522+
jvpValue = builder.emitCopyValueOperation(loc, jvpValue);
523+
});
522524
}
523525

524526
// If JVP has not yet been found, emit an `differentiable_function`
@@ -609,11 +611,13 @@ class JVPCloner::Implementation final
609611
// Record the `differentiable_function` instruction.
610612
context.getDifferentiableFunctionInstWorklist().push_back(diffFuncInst);
611613

612-
auto borrowedADFunc = builder.emitBeginBorrowOperation(loc, diffFuncInst);
613-
auto extractedJVP = builder.createDifferentiableFunctionExtract(
614-
loc, NormalDifferentiableFunctionTypeComponent::JVP, borrowedADFunc);
615-
jvpValue = builder.emitCopyValueOperation(loc, extractedJVP);
616-
builder.emitEndBorrowOperation(loc, borrowedADFunc);
614+
builder.emitScopedBorrowOperation(
615+
loc, diffFuncInst, [&](SILValue borrowedADFunc) {
616+
auto extractedJVP = builder.createDifferentiableFunctionExtract(
617+
loc, NormalDifferentiableFunctionTypeComponent::JVP,
618+
borrowedADFunc);
619+
jvpValue = builder.emitCopyValueOperation(loc, extractedJVP);
620+
});
617621
builder.emitDestroyValueOperation(loc, diffFuncInst);
618622
}
619623

lib/SILOptimizer/Differentiation/VJPCloner.cpp

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -434,11 +434,13 @@ class VJPCloner::Implementation final
434434
loc, origCallee, SILType::getPrimitiveObjectType(origFnUnsubstType),
435435
/*withoutActuallyEscaping*/ false);
436436
}
437-
auto borrowedDiffFunc = builder.emitBeginBorrowOperation(loc, origCallee);
438-
vjpValue = builder.createDifferentiableFunctionExtract(
439-
loc, NormalDifferentiableFunctionTypeComponent::VJP,
440-
borrowedDiffFunc);
441-
vjpValue = builder.emitCopyValueOperation(loc, vjpValue);
437+
builder.emitScopedBorrowOperation(
438+
loc, origCallee, [&](SILValue borrowedDiffFunc) {
439+
vjpValue = builder.createDifferentiableFunctionExtract(
440+
loc, NormalDifferentiableFunctionTypeComponent::VJP,
441+
borrowedDiffFunc);
442+
vjpValue = builder.emitCopyValueOperation(loc, vjpValue);
443+
});
442444
auto vjpFnType = vjpValue->getType().castTo<SILFunctionType>();
443445
auto vjpFnUnsubstType = vjpFnType->getUnsubstitutedType(getModule());
444446
if (vjpFnType != vjpFnUnsubstType) {
@@ -540,11 +542,14 @@ class VJPCloner::Implementation final
540542
// Record the `differentiable_function` instruction.
541543
context.getDifferentiableFunctionInstWorklist().push_back(diffFuncInst);
542544

543-
auto borrowedADFunc = builder.emitBeginBorrowOperation(loc, diffFuncInst);
544-
auto extractedVJP = getBuilder().createDifferentiableFunctionExtract(
545-
loc, NormalDifferentiableFunctionTypeComponent::VJP, borrowedADFunc);
546-
vjpValue = builder.emitCopyValueOperation(loc, extractedVJP);
547-
builder.emitEndBorrowOperation(loc, borrowedADFunc);
545+
builder.emitScopedBorrowOperation(
546+
loc, diffFuncInst, [&](SILValue borrowedADFunc) {
547+
auto extractedVJP =
548+
getBuilder().createDifferentiableFunctionExtract(
549+
loc, NormalDifferentiableFunctionTypeComponent::VJP,
550+
borrowedADFunc);
551+
vjpValue = builder.emitCopyValueOperation(loc, extractedVJP);
552+
});
548553
builder.emitDestroyValueOperation(loc, diffFuncInst);
549554
}
550555

test/AutoDiff/SILOptimizer/derivative_sil.swift

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ func foo(_ x: Float) -> Float {
3333
// CHECK-SIL: [[ADD_VJP_REF:%.*]] = differentiability_witness_function [vjp] [parameters 0 1] [results 0] @add
3434
// CHECK-SIL: [[ADD_DIFF_FN:%.*]] = differentiable_function [parameters 0 1] [results 0] [[ADD_ORIG_REF]] : $@convention(method) (Float, Float, @thin Float.Type) -> Float with_derivative {[[ADD_JVP_REF]] : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float), [[ADD_VJP_REF]] : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float) -> (Float, Float))}
3535
// CHECK-SIL: [[ADD_JVP_FN:%.*]] = differentiable_function_extract [jvp] [[ADD_DIFF_FN]]
36-
// CHECK-SIL: end_borrow [[ADD_DIFF_FN]]
3736
// CHECK-SIL: [[ADD_RESULT:%.*]] = apply [[ADD_JVP_FN]]([[X]], [[X]], {{.*}})
3837
// CHECK-SIL: ([[ORIG_RES:%.*]], [[ADD_DF:%.*]]) = destructure_tuple [[ADD_RESULT]]
3938
// CHECK-SIL: [[DF_STRUCT:%.*]] = struct $_AD__foo_bb0__DF__src_0_wrt_0 ([[ADD_DF]] : $@callee_guaranteed (Float, Float) -> Float)
@@ -58,7 +57,6 @@ func foo(_ x: Float) -> Float {
5857
// CHECK-SIL: [[ADD_VJP_REF:%.*]] = differentiability_witness_function [vjp] [parameters 0 1] [results 0] @add
5958
// CHECK-SIL: [[ADD_DIFF_FN:%.*]] = differentiable_function [parameters 0 1] [results 0] [[ADD_ORIG_REF]] : $@convention(method) (Float, Float, @thin Float.Type) -> Float with_derivative {[[ADD_JVP_REF]] : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float), [[ADD_VJP_REF]] : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float) -> (Float, Float))}
6059
// CHECK-SIL: [[ADD_VJP_FN:%.*]] = differentiable_function_extract [vjp] [[ADD_DIFF_FN]]
61-
// CHECK-SIL: end_borrow [[ADD_DIFF_FN]]
6260
// CHECK-SIL: [[ADD_RESULT:%.*]] = apply [[ADD_VJP_FN]]([[X]], [[X]], {{.*}})
6361
// CHECK-SIL: ([[ORIG_RES:%.*]], [[ADD_PB:%.*]]) = destructure_tuple [[ADD_RESULT]]
6462
// CHECK-SIL: [[PB_STRUCT:%.*]] = struct $_AD__foo_bb0__PB__src_0_wrt_0 ([[ADD_PB]] : $@callee_guaranteed (Float) -> (Float, Float))

0 commit comments

Comments
 (0)