Skip to content

Commit 711591e

Browse files
dan-zhengbgogul
authored andcommitted
[AutoDiff] Destroy unused pullback direct results. (#28207)
In `PullbackEmitter::visitApplyInst`, destroy unused pullback direct results. This is needed for VJPs extracted from `@differentiable` function callees, where the `@differentiable` function's differentiation parameter indices are a superset of the active `apply` parameter indices. Resolves TF-953.
1 parent a58b1df commit 711591e

File tree

2 files changed

+15
-6
lines changed

2 files changed

+15
-6
lines changed

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6923,8 +6923,8 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
69236923
? *applyInfo.originalPullbackType
69246924
: pullbackType;
69256925
for (auto indRes : actualPullbackType->getIndirectFormalResults()) {
6926-
auto *alloc =
6927-
builder.createAllocStack(loc, remapType(indRes.getSILStorageInterfaceType()));
6926+
auto *alloc = builder.createAllocStack(
6927+
loc, remapType(indRes.getSILStorageInterfaceType()));
69286928
pullbackIndirectResults.push_back(alloc);
69296929
args.push_back(alloc);
69306930
}
@@ -6975,13 +6975,22 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
69756975
addToAdjointBuffer(bb, origArg, tmpBuf, loc);
69766976
builder.emitDestroyAddrAndFold(loc, tmpBuf);
69776977
builder.createDeallocStack(loc, tmpBuf);
6978-
}
6979-
else {
6978+
} else {
69806979
recordTemporary(tan);
69816980
addAdjointValue(bb, origArg, makeConcreteAdjointValue(tan), loc);
69826981
}
69836982
}
69846983
}
6984+
// Destroy unused pullback direct results. Needed for pullback results from
6985+
// VJPs extracted from `@differentiable` function callees, where the
6986+
// `@differentiable` function's differentiation parameter indices are a
6987+
// superset of the active `apply` parameter indices.
6988+
while (allResultsIt != allResults.end()) {
6989+
auto unusedPullbackDirectResult = *allResultsIt++;
6990+
if (unusedPullbackDirectResult->getType().isAddress())
6991+
continue;
6992+
builder.emitDestroyValueOperation(loc, unusedPullbackDirectResult);
6993+
}
69856994
// Destroy and deallocate pullback indirect results.
69866995
for (auto *alloc : llvm::reverse(pullbackIndirectResults)) {
69876996
builder.emitDestroyAddrAndFold(loc, alloc);

test/AutoDiff/compiler_crashers/tf928-pullback-ownership-memory-leak.swift renamed to test/AutoDiff/compiler_crashers_fixed/tf928-pullback-ownership-memory-leak.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
// RUN: not --crash %target-swift-emit-sil %s
2-
// REQUIRES: asserts
1+
// RUN: %target-swift-emit-sil %s
32

43
// TF-928: Ownership verification error in pullback function generated by the
54
// differentiation transform.
@@ -27,6 +26,7 @@ func TF_928(
2726
_ x: Tracked<Float>
2827
) {
2928
_ = pullback(at: x) { x in lossFunction(x, Tracked<Float>()) }
29+
_ = pullback(at: x) { x in lossFunction(Tracked<Float>(), x) }
3030
}
3131

3232
// Function: 'AD__$s4main6TF_928yyAA7TrackedVySfGAE_AEtXF_AEtFA2EcfU___pullback_src_0_wrt_0'

0 commit comments

Comments
 (0)