Skip to content

Commit 0d17ddf

Browse files
authored
[AutoDiff] Destroy all pullback indirect results after adjoint accumulation. (#27711)
When we differentiate a function (example below) with respect to a proper subset of its indirect parameters, and the function only has a derivative with respect to a proper superset of those indirect parameters, the pullback returns more indirect results than we need. However, the unneeded indirect results are not destroyed, which causes a memory lifetime verification failure. This patch fixes the bug by releasing all pullback indirect results instead of releasing only the ones needed for calculating the derivative.

```swift
@differentiable(wrt: x)
func foo<T: Differentiable>(_ x: T, _ y: T, apply: @differentiable (T, T) -> T) -> T {
  return apply(x, y)
}
```

This patch also uncomments a test in test/AutoDiff/superset_adjoint.swift, which now passes. This fixes a FIXME. Resolves [TF-914](https://bugs.swift.org/browse/TF-914).
1 parent bafacd8 commit 0d17ddf

File tree

2 files changed

+26
-21
lines changed

2 files changed

+26
-21
lines changed

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6955,7 +6955,6 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
69556955
auto tan = *allResultsIt++;
69566956
if (tan->getType().isAddress()) {
69576957
addToAdjointBuffer(bb, origArg, tan, loc);
6958-
builder.emitDestroyAddrAndFold(loc, tan);
69596958
} else {
69606959
if (origArg->getType().isAddress()) {
69616960
auto *tmpBuf = builder.createAllocStack(loc, tan->getType());
@@ -6971,9 +6970,11 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
69716970
}
69726971
}
69736972
}
6974-
// Deallocate pullback indirect results.
6975-
for (auto *alloc : reversed(pullbackIndirectResults))
6973+
// Destroy and deallocate pullback indirect results.
6974+
for (auto *alloc : reversed(pullbackIndirectResults)) {
6975+
builder.emitDestroyAddrAndFold(loc, alloc);
69766976
builder.createDeallocStack(loc, alloc);
6977+
}
69776978
}
69786979

69796980
/// Handle `struct` instruction.

test/AutoDiff/superset_adjoint.swift

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -41,30 +41,34 @@ SupersetVJPTests.test("SubsetOfSubset") {
4141
expectEqual(0, gradient(at: 0, in: { x in foo(x, 0, 0) }))
4242
}
4343

44+
SupersetVJPTests.test("ApplySubset") {
45+
// TF-914
46+
@differentiable(wrt: x)
47+
func foo<T: Differentiable>(_ x: T, _ y: T, apply: @differentiable (T, T) -> T) -> T {
48+
return apply(x, y)
49+
}
50+
expectEqual(1, gradient(at: Float(0)) { x in foo(x, 0) { $0 + $1 } })
51+
}
52+
4453
// FIXME: The expression `(+) as @differentiable (Float, @nondiff Float) -> Float`
4554
// forms a curry thunk of `Float.+` before conversion to @differentiable, and AD
4655
// doesn't know how to differentiate the curry thunk, so it produces a
4756
// "function is not differentiable" error.
48-
// FIXME: Propagate wrt indices correctly so that this actually takes the
49-
// gradient wrt only the first parameter, as intended.
5057
// SupersetVJPTests.test("CrossModule") {
51-
// expectEqual(1, gradient(at: 1, 2, in: (+) as @differentiable (Float, @nondiff Float) -> Float))
58+
// let grad = gradient(at: Float(1), Float(2), in: (+) as @differentiable (Float, @nondiff Float) -> Float)
59+
// expectEqual(Float(1), grad)
5260
// }
5361

54-
// FIXME: Unbreak this one.
55-
//
56-
// @differentiable(wrt: (.0, .1), vjp: dx_T)
57-
// func x_T<T : Differentiable>(_ x: Float, _ y: T) -> Float {
58-
// if x > 1000 {
59-
// return x
60-
// }
61-
// return x
62-
// }
63-
// func dx_T<T>(_ x: Float, _ y: T) -> (Float, (Float) -> (Float, T.TangentVector)) {
64-
// return (x_T(x, y), { seed in (x, y) })
65-
// }
66-
// SupersetVJPTests.test("IndirectResults") {
67-
// expectEqual(3, gradient(at: 2) { x in x_T(x, Float(3)) })
68-
// }
62+
SupersetVJPTests.test("IndirectResults") {
63+
@differentiable(wrt: (x, y), vjp: dx_T)
64+
func x_T<T : Differentiable>(_ x: Float, _ y: T) -> Float {
65+
if x > 1000 { return x }
66+
return x
67+
}
68+
func dx_T<T : Differentiable>(_ x: Float, _ y: T) -> (Float, (Float) -> (Float, T.TangentVector)) {
69+
return (x_T(x, y), { v in (x * v, .zero) })
70+
}
71+
expectEqual(2, gradient(at: 2) { x in x_T(x, Float(3)) })
72+
}
6973

7074
runAllTests()

0 commit comments

Comments
 (0)