Skip to content

Commit c705359

Browse files
authored
[AutoDiff] Partially fix AD control flow memory leaks (#25294)
This PR partially fixes the incorrect omission of memory cleanup in `AdjointEmitter`. Cleanups should be created per basic block in an adjoint function, starting with its arguments. Like function parameter arguments in the adjoint entry block, every phi argument except the pullback struct in a non-entry adjoint block also needs to possess a cleanup. We let the existing adjoint emission logic accumulate the cleanups until we are about to branch, where we disable the top-level cleanup for arguments in the terminator instruction and apply all of their child cleanups recursively. This logic is more general than the original single-block differentiation logic, and we should refactor `AdjointEmitter::run()` to eliminate special logic for adjoint entry and exit blocks as much as possible. This PR also unifies the argument order of the adjoint entry block with that of other adjoint blocks, making the pullback struct always the last argument. This helps us reduce special logic for the adjoint entry block later.
1 parent f9872fc commit c705359

File tree

3 files changed

+47
-28
lines changed

3 files changed

+47
-28
lines changed

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4106,6 +4106,10 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
41064106
}
41074107

41084108
public:
4109+
//--------------------------------------------------------------------------//
4110+
// Entry point
4111+
//--------------------------------------------------------------------------//
4112+
41094113
/// Performs adjoint synthesis on the empty adjoint function. Returns true if
41104114
/// any error occurs.
41114115
bool run() {
@@ -4187,11 +4191,6 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
41874191
continue;
41884192
}
41894193

4190-
// Otherwise, we create a phi argument for the corresponding pullback
4191-
// struct, and handle dominated active values/buffers.
4192-
auto *pbStructArg = adjointBB->createPhiArgument(
4193-
pbStructLoweredType, ValueOwnershipKind::Guaranteed);
4194-
adjointPullbackStructArguments[origBB] = pbStructArg;
41954194
// Get all active values in the original block.
41964195
// If the original block has no active values, continue.
41974196
auto &bbActiveValues = activeValues[origBB];
@@ -4217,6 +4216,10 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
42174216
activeValueAdjointBBArgumentMap[{origBB, activeValue}] = adjointArg;
42184217
}
42194218
}
4219+
// Add a pullback struct argument.
4220+
auto *pbStructArg = adjointBB->createPhiArgument(
4221+
pbStructLoweredType, ValueOwnershipKind::Guaranteed);
4222+
adjointPullbackStructArguments[origBB] = pbStructArg;
42204223
// - Create adjoint trampoline blocks for each successor block of the
42214224
// original block. Adjoint trampoline blocks only have a pullback
42224225
// struct argument, and branch from the adjoint successor block to the
@@ -4364,8 +4367,6 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
43644367
assert(adjointSuccBB && adjointSuccBB->getNumArguments() == 1);
43654368
SILBuilder adjointTrampolineBBBuilder(adjointSuccBB);
43664369
SmallVector<SILValue, 8> trampolineArguments;
4367-
// Propagate pullback struct argument.
4368-
trampolineArguments.push_back(adjointSuccBB->getArguments().front());
43694370
// Propagate adjoint values/buffers of active values/buffers to
43704371
// predecessor blocks.
43714372
auto &predBBActiveValues = activeValues[predBB];
@@ -4374,6 +4375,11 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
43744375
auto activeValueAdj = getAdjointValue(bb, activeValue);
43754376
auto concreteActiveValueAdj =
43764377
materializeAdjointDirect(activeValueAdj, adjLoc);
4378+
// Emit cleanups for children.
4379+
if (auto *cleanup = concreteActiveValueAdj.getCleanup()) {
4380+
cleanup->disable();
4381+
cleanup->applyRecursively(builder, activeValue.getLoc());
4382+
}
43774383
trampolineArguments.push_back(concreteActiveValueAdj);
43784384
// If the adjoint block does not yet have a registered adjoint
43794385
// value for the active value, set the adjoint value to the
@@ -4383,9 +4389,9 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
43834389
if (!hasAdjointValue(predBB, activeValue)) {
43844390
auto *adjointBBArg =
43854391
getActiveValueAdjointBlockArgument(predBB, activeValue);
4386-
// FIXME: Propagate cleanups to fix memory leaks.
4387-
auto forwardedArgAdj =
4388-
makeConcreteAdjointValue(ValueWithCleanup(adjointBBArg));
4392+
auto forwardedArgAdj = makeConcreteAdjointValue(
4393+
ValueWithCleanup(adjointBBArg,
4394+
makeCleanup(adjointBBArg, emitCleanup)));
43894395
initializeAdjointValue(predBB, activeValue, forwardedArgAdj);
43904396
}
43914397
} else {
@@ -4399,6 +4405,8 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
43994405
adjLoc, adjBuf, predAdjBuf, IsNotTake, IsNotInitialization);
44004406
}
44014407
}
4408+
// Propagate pullback struct argument.
4409+
trampolineArguments.push_back(adjointSuccBB->getArguments().front());
44024410
// Branch from adjoint trampoline block to adjoint block.
44034411
adjointTrampolineBBBuilder.createBranch(
44044412
adjLoc, adjointBB, trampolineArguments);

test/AutoDiff/control_flow_sil.swift

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -72,32 +72,43 @@ func cond(_ x: Float) -> Float {
7272
// CHECK-SIL: switch_enum [[BB3_PRED]] : $_AD__cond_bb3__Pred__src_0_wrt_0, case #_AD__cond_bb3__Pred__src_0_wrt_0.bb2!enumelt.1: bb3, case #_AD__cond_bb3__Pred__src_0_wrt_0.bb1!enumelt.1: bb1
7373

7474
// CHECK-SIL: bb1([[BB3_PRED1_TRAMP_PB_STRUCT:%.*]] : $_AD__cond_bb1__PB__src_0_wrt_0):
75-
// CHECK-SIL: br bb2([[BB3_PRED1_TRAMP_PB_STRUCT]] : $_AD__cond_bb1__PB__src_0_wrt_0, {{%.*}} : $Float, {{%.*}}: $Float)
75+
// CHECK-SIL: br bb2({{%.*}} : $Float, {{%.*}}: $Float, [[BB3_PRED1_TRAMP_PB_STRUCT]] : $_AD__cond_bb1__PB__src_0_wrt_0)
7676

77-
// CHECK-SIL: bb2([[BB1_PB_STRUCT:%.*]] : $_AD__cond_bb1__PB__src_0_wrt_0, {{%.*}} : $Float, {{%.*}} : $Float):
77+
// CHECK-SIL: bb2({{%.*}} : $Float, {{%.*}} : $Float, [[BB1_PB_STRUCT:%.*]] : $_AD__cond_bb1__PB__src_0_wrt_0):
7878
// CHECK-SIL: [[BB1_PB:%.*]] = struct_extract [[BB1_PB_STRUCT]]
7979
// CHECK-SIL: [[BB1_ADJVALS:%.*]] = apply [[BB1_PB]]([[SEED]]) : $@callee_guaranteed (Float) -> (Float, Float)
8080
// CHECK-SIL: [[BB1_PRED:%.*]] = struct_extract [[BB1_PB_STRUCT]]
81+
// CHECK-SIL: release_value {{%.*}} : $Float
82+
// CHECK-SIL: release_value {{%.*}} : $Float
83+
// CHECK-SIL: release_value {{%.*}} : $Float
84+
// CHECK-SIL: release_value {{%.*}} : $Float
85+
// CHECK-SIL: release_value {{%.*}} : $Float
8186
// CHECK-SIL: [[BB1_PB_STRUCT_DATA:%.*]] = unchecked_enum_data [[BB1_PRED]]
8287
// CHECK-SIL: br bb5([[BB1_PB_STRUCT_DATA]] : $_AD__cond_bb0__PB__src_0_wrt_0)
8388

8489
// CHECK-SIL: bb3([[BB3_PRED2_TRAMP_PB_STRUCT:%.*]] : $_AD__cond_bb2__PB__src_0_wrt_0):
85-
// CHECK-SIL: br bb4([[BB3_PRED2_TRAMP_PB_STRUCT]] : $_AD__cond_bb2__PB__src_0_wrt_0, {{%.*}} : $Float, {{%.*}}: $Float)
90+
// CHECK-SIL: br bb4({{%.*}} : $Float, {{%.*}}: $Float, [[BB3_PRED2_TRAMP_PB_STRUCT]] : $_AD__cond_bb2__PB__src_0_wrt_0)
8691

87-
// CHECK-SIL: bb4([[BB2_PB_STRUCT:%.*]] : $_AD__cond_bb2__PB__src_0_wrt_0, {{%.*}} : $Float, {{%.*}} : $Float):
92+
// CHECK-SIL: bb4({{%.*}} : $Float, {{%.*}} : $Float, [[BB2_PB_STRUCT:%.*]] : $_AD__cond_bb2__PB__src_0_wrt_0):
8893
// CHECK-SIL: [[BB2_PB:%.*]] = struct_extract [[BB2_PB_STRUCT]]
8994
// CHECK-SIL: [[BB2_ADJVALS:%.*]] = apply [[BB2_PB]]([[SEED]]) : $@callee_guaranteed (Float) -> (Float, Float)
9095
// CHECK-SIL: [[BB2_PRED:%.*]] = struct_extract [[BB2_PB_STRUCT]]
96+
// CHECK-SIL: release_value {{%.*}} : $Float
97+
// CHECK-SIL: release_value {{%.*}} : $Float
98+
// CHECK-SIL: release_value {{%.*}} : $Float
99+
// CHECK-SIL: release_value {{%.*}} : $Float
91100
// CHECK-SIL: [[BB2_PB_STRUCT_DATA:%.*]] = unchecked_enum_data [[BB2_PRED]]
92101
// CHECK-SIL: br bb6([[BB2_PB_STRUCT_DATA]] : $_AD__cond_bb0__PB__src_0_wrt_0)
93102

94103
// CHECK-SIL: bb5([[BB1_PRED0_TRAMP_PB_STRUCT:%.*]] : $_AD__cond_bb0__PB__src_0_wrt_0):
95-
// CHECK-SIL: br bb7([[BB1_PRED0_TRAMP_PB_STRUCT]] : $_AD__cond_bb0__PB__src_0_wrt_0, {{%.*}} : $Float)
104+
// CHECK-SIL: br bb7({{%.*}} : $Float, [[BB1_PRED0_TRAMP_PB_STRUCT]] : $_AD__cond_bb0__PB__src_0_wrt_0)
96105

97106
// CHECK-SIL: bb6([[BB2_PRED0_TRAMP_PB_STRUCT:%.*]] : $_AD__cond_bb0__PB__src_0_wrt_0):
98-
// CHECK-SIL: br bb7([[BB2_PRED0_TRAMP_PB_STRUCT]] : $_AD__cond_bb0__PB__src_0_wrt_0, {{%.*}} : $Float)
107+
// CHECK-SIL: br bb7({{%.*}} : $Float, [[BB2_PRED0_TRAMP_PB_STRUCT]] : $_AD__cond_bb0__PB__src_0_wrt_0)
99108

100-
// CHECK-SIL: bb7([[BB0_PB_STRUCT:%.*]] : $_AD__cond_bb0__PB__src_0_wrt_0, {{%.*}} : $Float):
109+
// CHECK-SIL: bb7({{%.*}} : $Float, [[BB0_PB_STRUCT:%.*]] : $_AD__cond_bb0__PB__src_0_wrt_0):
110+
// CHECK-SIL: release_value {{%.*}} : $Float
111+
// CHECK-SIL: release_value {{%.*}} : $Float
101112
// CHECK-SIL: return {{%.*}} : $Float
102113

103114
@differentiable
@@ -147,30 +158,30 @@ func cond_tuple_var(_ x: Float) -> Float {
147158
// CHECK-SIL: switch_enum [[BB3_PRED]] : $_AD__cond_tuple_var_bb3__Pred__src_0_wrt_0, case #_AD__cond_tuple_var_bb3__Pred__src_0_wrt_0.bb2!enumelt.1: bb3, case #_AD__cond_tuple_var_bb3__Pred__src_0_wrt_0.bb1!enumelt.1: bb1
148159

149160
// CHECK-SIL: bb1([[BB3_PRED1_TRAMP_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb1__PB__src_0_wrt_0):
150-
// CHECK-SIL: br bb2([[BB3_PRED1_TRAMP_PB_STRUCT]] : $_AD__cond_tuple_var_bb1__PB__src_0_wrt_0, {{%.*}} : $Float, {{%.*}}: $Float)
161+
// CHECK-SIL: br bb2({{%.*}} : $Float, {{%.*}} : $Float, [[BB3_PRED1_TRAMP_PB_STRUCT]] : $_AD__cond_tuple_var_bb1__PB__src_0_wrt_0)
151162

152-
// CHECK-SIL: bb2([[BB1_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb1__PB__src_0_wrt_0, {{%.*}} : $Float, {{%.*}} : $Float):
163+
// CHECK-SIL: bb2({{%.*}} : $Float, {{%.*}} : $Float, [[BB1_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb1__PB__src_0_wrt_0):
153164
// CHECK-SIL: [[BB1_PRED:%.*]] = struct_extract [[BB1_PB_STRUCT]]
154165
// CHECK-SIL: copy_addr {{%.*}} to {{%.*}} : $*(Float, Float)
155166
// CHECK-SIL-NOT: copy_addr {{%.*}} to {{%.*}} : $*Float
156167
// CHECK-SIL: [[BB1_PB_STRUCT_DATA:%.*]] = unchecked_enum_data [[BB1_PRED]]
157168
// CHECK-SIL: br bb5([[BB1_PB_STRUCT_DATA]] : $_AD__cond_tuple_var_bb0__PB__src_0_wrt_0)
158169

159170
// CHECK-SIL: bb3([[BB3_PRED2_TRAMP_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb2__PB__src_0_wrt_0):
160-
// CHECK-SIL: br bb4([[BB3_PRED2_TRAMP_PB_STRUCT]] : $_AD__cond_tuple_var_bb2__PB__src_0_wrt_0, {{%.*}} : $Float, {{%.*}}: $Float)
171+
// CHECK-SIL: br bb4({{%.*}} : $Float, {{%.*}} : $Float, [[BB3_PRED2_TRAMP_PB_STRUCT]] : $_AD__cond_tuple_var_bb2__PB__src_0_wrt_0)
161172

162-
// CHECK-SIL: bb4([[BB2_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb2__PB__src_0_wrt_0, {{%.*}} : $Float, {{%.*}} : $Float):
173+
// CHECK-SIL: bb4({{%.*}} : $Float, {{%.*}} : $Float, [[BB2_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb2__PB__src_0_wrt_0):
163174
// CHECK-SIL: [[BB2_PRED:%.*]] = struct_extract [[BB2_PB_STRUCT]]
164175
// CHECK-SIL: copy_addr {{%.*}} to {{%.*}} : $*(Float, Float)
165176
// CHECK-SIL-NOT: copy_addr {{%.*}} to {{%.*}} : $*Float
166177
// CHECK-SIL: [[BB2_PB_STRUCT_DATA:%.*]] = unchecked_enum_data [[BB2_PRED]]
167178
// CHECK-SIL: br bb6([[BB2_PB_STRUCT_DATA]] : $_AD__cond_tuple_var_bb0__PB__src_0_wrt_0)
168179

169180
// CHECK-SIL: bb5([[BB1_PRED0_TRAMP_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb0__PB__src_0_wrt_0):
170-
// CHECK-SIL: br bb7([[BB1_PRED0_TRAMP_PB_STRUCT]] : $_AD__cond_tuple_var_bb0__PB__src_0_wrt_0, {{%.*}} : $Float)
181+
// CHECK-SIL: br bb7({{%.*}} : $Float, [[BB1_PRED0_TRAMP_PB_STRUCT]] : $_AD__cond_tuple_var_bb0__PB__src_0_wrt_0)
171182

172183
// CHECK-SIL: bb6([[BB2_PRED0_TRAMP_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb0__PB__src_0_wrt_0):
173-
// CHECK-SIL: br bb7([[BB2_PRED0_TRAMP_PB_STRUCT]] : $_AD__cond_tuple_var_bb0__PB__src_0_wrt_0, {{%.*}} : $Float)
184+
// CHECK-SIL: br bb7({{%.*}} : $Float, [[BB2_PRED0_TRAMP_PB_STRUCT]] : $_AD__cond_tuple_var_bb0__PB__src_0_wrt_0)
174185

175-
// CHECK-SIL: bb7([[BB0_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb0__PB__src_0_wrt_0, {{%.*}} : $Float):
186+
// CHECK-SIL: bb7({{%.*}} : $Float, [[BB0_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb0__PB__src_0_wrt_0):
176187
// CHECK-SIL: return {{%.*}} : $Float

test/AutoDiff/leakchecking.swift

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ LeakCheckingTests.test("ControlFlow") {
9292
// FIXME: Fix control flow AD memory leaks.
9393
// See related FIXME comments in adjoint value/buffer propagation in
9494
// lib/SILOptimizer/Mandatory/Differentiation.cpp.
95-
testWithLeakChecking(expectedLeakCount: 105) {
95+
testWithLeakChecking(expectedLeakCount: 74) {
9696
func cond_nestedtuple_var(_ x: Tracked<Float>) -> Tracked<Float> {
9797
// Convoluted function returning `x + x`.
9898
var y = (x + x, x - x)
@@ -116,7 +116,7 @@ LeakCheckingTests.test("ControlFlow") {
116116
// FIXME: Fix control flow AD memory leaks.
117117
// See related FIXME comments in adjoint value/buffer propagation in
118118
// lib/SILOptimizer/Mandatory/Differentiation.cpp.
119-
testWithLeakChecking(expectedLeakCount: 379) {
119+
testWithLeakChecking(expectedLeakCount: 300) {
120120
func cond_nestedstruct_var(_ x: Tracked<Float>) -> Tracked<Float> {
121121
// Convoluted function returning `x + x`.
122122
var y = FloatPair(x + x, x - x)
@@ -140,7 +140,7 @@ LeakCheckingTests.test("ControlFlow") {
140140
// FIXME: Fix control flow AD memory leaks.
141141
// See related FIXME comments in adjoint value/buffer propagation in
142142
// lib/SILOptimizer/Mandatory/Differentiation.cpp.
143-
testWithLeakChecking(expectedLeakCount: 9) {
143+
testWithLeakChecking(expectedLeakCount: 3) {
144144
var model = ExampleLeakModel()
145145
let x: Tracked<Float> = 1.0
146146
_ = model.gradient(at: x) { m, x in
@@ -157,7 +157,7 @@ LeakCheckingTests.test("ControlFlow") {
157157
// FIXME: Fix control flow AD memory leaks.
158158
// See related FIXME comments in adjoint value/buffer propagation in
159159
// lib/SILOptimizer/Mandatory/Differentiation.cpp.
160-
testWithLeakChecking(expectedLeakCount: 14) {
160+
testWithLeakChecking(expectedLeakCount: 6) {
161161
var model = ExampleLeakModel()
162162
let x: Tracked<Float> = 1.0
163163
_ = model.gradient(at: x) { m, x in

0 commit comments

Comments
 (0)