Skip to content

Commit 0641a22

Browse files
authored
[AutoDiff] Fix 'autodiff_function_extract' operand ownership kind and memory leaks. (#27199)
The `autodiff_function_extract` instruction behaves like `tuple_extract`, where it extracts some element from an aggregate. Its operand should have the same ownership kind as that of `tuple_extract`. That is, it should be defined as `CONSTANT_OR_TRIVIAL_OWNERSHIP_INST(Guaranteed, AutoDiffFunctionInst)` in ValueOwnershipKindClassifier. However, this is currently defined wrongly as `FORWARDING_OWNERSHIP_INST(AutoDiffFunctionExtract)`, which caused a bug in the differentiation transform to be uncaught: VJPEmitter and JVPEmitter in the differentiation transform is performing `autodiff_function_extract` on an `@owned` `@differentiable` function, which caused associated functions that are not extracted to be not released: ``` %f = autodiff_function %original %f_vjp = autodiff_function_extract [vjp] %f ... // %f is not released, and not caught by ownership verification! ``` After we fix the operand ownership kind for `autodiff_function_extract`, all these cases are now caught by ownership verification. The reproducer in [TF-795](https://bugs.swift.org/browse/TF-795) and most differentiation tests are failing to compile because ownership verification caught the bug in AD-generated code. The existing AD test suite is serving as good test cases for this ownership error. To fix this, VJPEmitter and JVPEmitter are now changed to emit borrows of `@differentiable` functions and copies of associated functions and property destroying the `@differentiable` function: ``` %f = autodiff_function %original %f_borrowed = begin_borrow %f %f_vjp_extracted = autodiff_function_extract [vjp] %f_borrowed %f_vjp = copy_value %f_vjp_extracted end_borrow %f_borrowed destroy_value %f ``` Fixes [TF-795](https://bugs.swift.org/browse/TF-795).
1 parent dd4fd33 commit 0641a22

File tree

2 files changed

+16
-7
lines changed

2 files changed

+16
-7
lines changed

lib/SIL/ValueOwnership.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,8 @@ CONSTANT_OWNERSHIP_INST(Unowned, ValueToBridgeObject)
163163
}
164164
CONSTANT_OR_TRIVIAL_OWNERSHIP_INST(Guaranteed, StructExtract)
165165
CONSTANT_OR_TRIVIAL_OWNERSHIP_INST(Guaranteed, TupleExtract)
166+
// SWIFT_ENABLE_TENSORFLOW
167+
CONSTANT_OR_TRIVIAL_OWNERSHIP_INST(Guaranteed, AutoDiffFunctionExtract)
166168
// OpenExistentialValue opens the boxed value inside an existential
167169
// CoW box. The semantics of an existential CoW box implies that we
168170
// can only consume the projected value inside the box if the box is
@@ -254,7 +256,6 @@ FORWARDING_OWNERSHIP_INST(SelectEnum)
254256
FORWARDING_OWNERSHIP_INST(Enum)
255257
// SWIFT_ENABLE_TENSORFLOW
256258
FORWARDING_OWNERSHIP_INST(AutoDiffFunction)
257-
FORWARDING_OWNERSHIP_INST(AutoDiffFunctionExtract)
258259
#undef FORWARDING_OWNERSHIP_INST
259260

260261
ValueOwnershipKind

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3749,9 +3749,14 @@ class VJPEmitter final
37493749
context.getResultIndices()[autoDiffFuncInst] =
37503750
activeResultIndices.front();
37513751

3752-
vjpValue = getBuilder().createAutoDiffFunctionExtract(
3752+
auto borrowedADFunc =
3753+
builder.emitBeginBorrowOperation(loc, autoDiffFuncInst);
3754+
auto extractedVJP = getBuilder().createAutoDiffFunctionExtract(
37533755
loc, AutoDiffFunctionExtractInst::Extractee::VJP,
3754-
/*differentiationOrder*/ 1, autoDiffFuncInst);
3756+
/*differentiationOrder*/ 1, borrowedADFunc);
3757+
vjpValue = builder.emitCopyValueOperation(loc, extractedVJP);
3758+
builder.emitEndBorrowOperation(loc, borrowedADFunc);
3759+
builder.emitDestroyValueOperation(loc, autoDiffFuncInst);
37553760
}
37563761

37573762
// Record desired/actual VJP indices.
@@ -4850,7 +4855,6 @@ class JVPEmitter final
48504855
// on the remapped original function operand and `autodiff_function_extract`
48514856
// the JVP. The actual JVP functions will be populated in the
48524857
// `autodiff_function` during the transform main loop.
4853-
SILValue differentiableFunc;
48544858
if (!jvpValue) {
48554859
// FIXME: Handle indirect differentiation invokers. This may require some
48564860
// redesign: currently, each original function + attribute pair is mapped
@@ -4911,16 +4915,20 @@ class JVPEmitter final
49114915
auto *autoDiffFuncInst =
49124916
context.createAutoDiffFunction(builder, loc, indices.parameters,
49134917
/*differentiationOrder*/ 1, original);
4914-
differentiableFunc = autoDiffFuncInst;
49154918

49164919
// Record the `autodiff_function` instruction.
49174920
context.getAutoDiffFunctionInsts().push_back(autoDiffFuncInst);
49184921
context.getResultIndices()[autoDiffFuncInst] =
49194922
activeResultIndices.front();
49204923

4921-
jvpValue = builder.createAutoDiffFunctionExtract(
4924+
auto borrowedADFunc =
4925+
builder.emitBeginBorrowOperation(loc, autoDiffFuncInst);
4926+
auto extractedJVP = builder.createAutoDiffFunctionExtract(
49224927
loc, AutoDiffFunctionExtractInst::Extractee::JVP,
4923-
/*differentiationOrder*/ 1, autoDiffFuncInst);
4928+
/*differentiationOrder*/ 1, borrowedADFunc);
4929+
jvpValue = builder.emitCopyValueOperation(loc, extractedJVP);
4930+
builder.emitEndBorrowOperation(loc, borrowedADFunc);
4931+
builder.emitDestroyValueOperation(loc, autoDiffFuncInst);
49244932
}
49254933

49264934
// Call the JVP using the original parameters.

0 commit comments

Comments
 (0)