Skip to content

Commit 1a94205

Browse files
rxwei and dan-zheng
authored and committed
[AutoDiff] Another partial fix for adjoint memory leaks. (#25300)
In addition to emitting cleanups recursively for return values' child cleanups, we collect and clean up all adjoint values during adjoint emission per block.
1 parent c705359 commit 1a94205

File tree

2 files changed

+70
-48
lines changed

2 files changed

+70
-48
lines changed

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 67 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -3519,15 +3519,15 @@ class AdjointValueBase {
35193519

35203520
/// The underlying value.
35213521
union Value {
3522-
MutableArrayRef<AdjointValue> aggregate;
3522+
ArrayRef<AdjointValue> aggregate;
35233523
ValueWithCleanup concrete;
3524-
Value(MutableArrayRef<AdjointValue> v) : aggregate(v) {}
3524+
Value(ArrayRef<AdjointValue> v) : aggregate(v) {}
35253525
Value(ValueWithCleanup v) : concrete(v) {}
35263526
Value() {}
35273527
} value;
35283528

35293529
explicit AdjointValueBase(SILType type,
3530-
MutableArrayRef<AdjointValue> aggregate)
3530+
ArrayRef<AdjointValue> aggregate)
35313531
: kind(AdjointValueKind::Aggregate), type(type), value(aggregate) {}
35323532

35333533
explicit AdjointValueBase(ValueWithCleanup v)
@@ -3591,11 +3591,15 @@ class AdjointValue final {
35913591
return base->value.aggregate.size();
35923592
}
35933593

3594-
AdjointValue takeAggregateElement(unsigned i) {
3594+
AdjointValue getAggregateElement(unsigned i) const {
35953595
assert(isAggregate());
35963596
return base->value.aggregate[i];
35973597
}
35983598

3599+
ArrayRef<AdjointValue> getAggregateElements() const {
3600+
return base->value.aggregate;
3601+
}
3602+
35993603
ValueWithCleanup getConcreteValue() const {
36003604
assert(isConcrete());
36013605
return base->value.concrete;
@@ -3684,6 +3688,11 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
36843688
/// Mapping from original basic blocks to dominated active values.
36853689
DenseMap<SILBasicBlock *, SmallVector<SILValue, 8>> activeValues;
36863690

3691+
/// Local adjoint values to be cleaned up. This is populated when adjoint
3692+
/// emission is run on one basic block and cleaned before processing another
3693+
/// basic block.
3694+
SmallVector<AdjointValue, 8> blockLocalAdjointValues;
3695+
36873696
/// Mapping from original basic blocks and original active values to
36883697
/// corresponding adjoint block arguments.
36893698
DenseMap<std::pair<SILBasicBlock *, SILValue>, SILArgument *>
@@ -3758,7 +3767,7 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
37583767
AdjointValue makeAggregateAdjointValue(SILType type, EltRange elements);
37593768

37603769
//--------------------------------------------------------------------------//
3761-
// Managed value materializers
3770+
// Symbolic value materializers
37623771
//--------------------------------------------------------------------------//
37633772

37643773
/// Materialize an adjoint value. The type of the given adjoint value must be
@@ -3777,7 +3786,7 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
37773786
AdjointValue val, ValueWithCleanup &destBufferAccess);
37783787

37793788
//--------------------------------------------------------------------------//
3780-
// Helpers for managed value materializers
3789+
// Helpers for symbolic value materializers
37813790
//--------------------------------------------------------------------------//
37823791

37833792
/// Emit a zero value by calling `AdditiveArithmetic.zero`. The given type
@@ -3788,6 +3797,12 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
37883797
/// must conform to `AdditiveArithmetic` and be loadable in SIL.
37893798
SILValue emitZeroDirect(CanType type, SILLocation loc);
37903799

3800+
//--------------------------------------------------------------------------//
3801+
// Memory cleanup tools
3802+
//--------------------------------------------------------------------------//
3803+
3804+
void emitCleanupForAdjointValue(AdjointValue value);
3805+
37913806
//--------------------------------------------------------------------------//
37923807
// Accumulator
37933808
//--------------------------------------------------------------------------//
@@ -3899,8 +3914,9 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
38993914
auto it = insertion.first;
39003915
auto &&existingValue = it->getSecond();
39013916
valueMap.erase(it);
3902-
initializeAdjointValue(origBB, originalValue,
3903-
accumulateAdjointsDirect(existingValue, newAdjointValue));
3917+
auto adjVal = accumulateAdjointsDirect(existingValue, newAdjointValue);
3918+
initializeAdjointValue(origBB, originalValue, adjVal);
3919+
blockLocalAdjointValues.push_back(adjVal);
39043920
}
39053921

39063922
/// Get the adjoint block argument corresponding to the given original block
@@ -4378,7 +4394,7 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
43784394
// Emit cleanups for children.
43794395
if (auto *cleanup = concreteActiveValueAdj.getCleanup()) {
43804396
cleanup->disable();
4381-
cleanup->applyRecursively(builder, activeValue.getLoc());
4397+
cleanup->applyRecursively(builder, adjLoc);
43824398
}
43834399
trampolineArguments.push_back(concreteActiveValueAdj);
43844400
// If the adjoint block does not yet have a registered adjoint
@@ -4415,11 +4431,15 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
44154431
getPullbackInfo().lookUpPredecessorEnumElement(predBB, bb);
44164432
adjointSuccessorCases.push_back({enumEltDecl, adjointSuccBB});
44174433
}
4418-
assert(adjointSuccessorCases.size() == predEnum->getNumElements());
4434+
// Emit cleanups for all block-local adjoint values.
4435+
for (auto adjVal : blockLocalAdjointValues)
4436+
emitCleanupForAdjointValue(adjVal);
4437+
blockLocalAdjointValues.clear();
44194438
// - If the original block has exactly one predecessor, then the adjoint
44204439
// block has exactly one successor. Extract the pullback struct value
44214440
// from the predecessor enum value using `unchecked_enum_data` and
44224441
// branch to the adjoint successor block.
4442+
assert(adjointSuccessorCases.size() == predEnum->getNumElements());
44234443
if (adjointSuccessorCases.size() == 1) {
44244444
auto *predBB = bb->getSinglePredecessorBlock();
44254445
assert(predBB);
@@ -4479,9 +4499,13 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
44794499
indParamAdjoints.push_back(adjBuf);
44804500
}
44814501
};
4482-
// Accumulate differentiation parameter adjoints.
4502+
// Collect differentiation parameter adjoints.
44834503
for (auto i : getIndices().parameters->getIndices())
44844504
addRetElt(i);
4505+
// Emit cleanups for all local values.
4506+
for (auto adjVal : blockLocalAdjointValues)
4507+
emitCleanupForAdjointValue(adjVal);
4508+
blockLocalAdjointValues.clear();
44854509

44864510
// Disable cleanup for original indirect parameter adjoint buffers.
44874511
// Copy them to adjoint indirect results.
@@ -4667,8 +4691,9 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
46674691
builder.createDeallocStack(loc, tmpBuf);
46684692
}
46694693
else
4670-
addAdjointValue(bb, origArg, makeConcreteAdjointValue(ValueWithCleanup(
4671-
tan, makeCleanup(tan, emitCleanup, {seed.getCleanup()}))));
4694+
addAdjointValue(bb, origArg, makeConcreteAdjointValue(
4695+
ValueWithCleanup(tan,
4696+
makeCleanup(tan, emitCleanup, {seed.getCleanup()}))));
46724697
}
46734698
}
46744699
// Deallocate pullback indirect results.
@@ -4857,7 +4882,7 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
48574882
for (auto i : range(ti->getElements().size())) {
48584883
if (!getTangentSpace(ti->getElement(i)->getType().getASTType()))
48594884
continue;
4860-
addAdjointValue(bb, ti->getElement(i), av.takeAggregateElement(adjIdx++));
4885+
addAdjointValue(bb, ti->getElement(i), av.getAggregateElement(adjIdx++));
48614886
}
48624887
break;
48634888
}
@@ -4964,18 +4989,7 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
49644989
addAdjointValue(bb, si->getSrc(), makeConcreteAdjointValue(
49654990
ValueWithCleanup(adjVal, valueCleanup)));
49664991
// Set the buffer to zero, with a cleanup.
4967-
auto *bai = dyn_cast<BeginAccessInst>(adjBuf.getValue());
4968-
if (bai && !(bai->getAccessKind() == SILAccessKind::Modify ||
4969-
bai->getAccessKind() == SILAccessKind::Init)) {
4970-
auto *modifyAccess = builder.createBeginAccess(
4971-
si->getLoc(), bai->getSource(), SILAccessKind::Modify,
4972-
SILAccessEnforcement::Static, /*noNestedConflict*/ true,
4973-
/*fromBuiltin*/ false);
4974-
emitZeroIndirect(bufType.getASTType(), modifyAccess, si->getLoc());
4975-
builder.createEndAccess(si->getLoc(), modifyAccess, /*aborted*/ false);
4976-
} else {
4977-
emitZeroIndirect(bufType.getASTType(), adjBuf, si->getLoc());
4978-
}
4992+
emitZeroIndirect(bufType.getASTType(), adjBuf, si->getLoc());
49794993
}
49804994

49814995
// Handle `copy_addr` instruction.
@@ -5001,18 +5015,7 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
50015015
addToAdjointBuffer(bb, cai->getSrc(), readAccess);
50025016
builder.createEndAccess(cai->getLoc(), readAccess, /*aborted*/ false);
50035017
// Set the buffer to zero, with a cleanup.
5004-
auto *bai = dyn_cast<BeginAccessInst>(adjDest.getValue());
5005-
if (bai && !(bai->getAccessKind() == SILAccessKind::Modify ||
5006-
bai->getAccessKind() == SILAccessKind::Init)) {
5007-
auto *modifyAccess = builder.createBeginAccess(
5008-
cai->getLoc(), bai->getSource(), SILAccessKind::Modify,
5009-
SILAccessEnforcement::Static, /*noNestedConflict*/ true,
5010-
/*fromBuiltin*/ false);
5011-
emitZeroIndirect(destType.getASTType(), modifyAccess, cai->getLoc());
5012-
builder.createEndAccess(cai->getLoc(), modifyAccess, /*aborted*/ false);
5013-
} else {
5014-
emitZeroIndirect(destType.getASTType(), adjDest, cai->getLoc());
5015-
}
5018+
emitZeroIndirect(destType.getASTType(), adjDest, cai->getLoc());
50165019
auto cleanup = makeCleanup(adjDest, emitCleanup);
50175020
adjDest.setCleanup(cleanup);
50185021
}
@@ -5140,7 +5143,7 @@ ValueWithCleanup AdjointEmitter::materializeAdjointDirect(
51405143
SmallVector<SILValue, 8> elements;
51415144
SmallVector<Cleanup *, 8> cleanups;
51425145
for (auto i : range(val.getNumAggregateElements())) {
5143-
auto eltVal = materializeAdjointDirect(val.takeAggregateElement(i), loc);
5146+
auto eltVal = materializeAdjointDirect(val.getAggregateElement(i), loc);
51445147
elements.push_back(eltVal.getValue());
51455148
cleanups.push_back(eltVal.getCleanup());
51465149
}
@@ -5216,7 +5219,7 @@ void AdjointEmitter::materializeAdjointIndirectHelper(
52165219
ValueWithCleanup eltBuf(
52175220
builder.createTupleElementAddr(loc, destBufferAccess, idx, eltTy),
52185221
/*cleanup*/ nullptr);
5219-
materializeAdjointIndirectHelper(val.takeAggregateElement(idx), eltBuf);
5222+
materializeAdjointIndirectHelper(val.getAggregateElement(idx), eltBuf);
52205223
destBufferAccess.setCleanup(makeCleanupFromChildren(
52215224
{destBufferAccess.getCleanup(), eltBuf.getCleanup()}));
52225225
}
@@ -5228,7 +5231,7 @@ void AdjointEmitter::materializeAdjointIndirectHelper(
52285231
ValueWithCleanup eltBuf(
52295232
builder.createStructElementAddr(loc, destBufferAccess, *fieldIt),
52305233
/*cleanup*/ nullptr);
5231-
materializeAdjointIndirectHelper(val.takeAggregateElement(i), eltBuf);
5234+
materializeAdjointIndirectHelper(val.getAggregateElement(i), eltBuf);
52325235
destBufferAccess.setCleanup(makeCleanupFromChildren(
52335236
{destBufferAccess.getCleanup(), eltBuf.getCleanup()}));
52345237
}
@@ -5293,6 +5296,25 @@ SILValue AdjointEmitter::emitZeroDirect(CanType type, SILLocation loc) {
52935296
return loaded;
52945297
}
52955298

5299+
void AdjointEmitter::emitCleanupForAdjointValue(AdjointValue value) {
5300+
switch (value.getKind()) {
5301+
case AdjointValueKind::Zero: return;
5302+
case AdjointValueKind::Aggregate:
5303+
for (auto element : value.getAggregateElements())
5304+
emitCleanupForAdjointValue(element);
5305+
break;
5306+
case AdjointValueKind::Concrete: {
5307+
auto concrete = value.getConcreteValue();
5308+
auto *cleanup = concrete.getCleanup();
5309+
LLVM_DEBUG(getADDebugStream() << "Applying "
5310+
<< cleanup->getNumChildren() << " for value "
5311+
<< concrete.getValue() << " child cleanups\n");
5312+
cleanup->applyRecursively(builder, concrete.getLoc());
5313+
break;
5314+
}
5315+
}
5316+
}
5317+
52965318
AdjointValue
52975319
AdjointEmitter::accumulateAdjointsDirect(AdjointValue lhs,
52985320
AdjointValue rhs) {
@@ -5324,7 +5346,7 @@ AdjointEmitter::accumulateAdjointsDirect(AdjointValue lhs,
53245346
for (auto idx : range(rhs.getNumAggregateElements())) {
53255347
auto lhsElt = builder.createTupleExtract(
53265348
lhsVal.getLoc(), lhsVal, idx);
5327-
auto rhsElt = rhs.takeAggregateElement(idx);
5349+
auto rhsElt = rhs.getAggregateElement(idx);
53285350
newElements.push_back(accumulateAdjointsDirect(
53295351
makeConcreteAdjointValue(
53305352
ValueWithCleanup(lhsElt, lhsVal.getCleanup())),
@@ -5336,7 +5358,7 @@ AdjointEmitter::accumulateAdjointsDirect(AdjointValue lhs,
53365358
++fieldIt, ++i) {
53375359
auto lhsElt = builder.createStructExtract(
53385360
lhsVal.getLoc(), lhsVal, *fieldIt);
5339-
auto rhsElt = rhs.takeAggregateElement(i);
5361+
auto rhsElt = rhs.getAggregateElement(i);
53405362
newElements.push_back(accumulateAdjointsDirect(
53415363
makeConcreteAdjointValue(
53425364
ValueWithCleanup(lhsElt, lhsVal.getCleanup())),
@@ -5365,8 +5387,8 @@ AdjointEmitter::accumulateAdjointsDirect(AdjointValue lhs,
53655387
SmallVector<AdjointValue, 8> newElements;
53665388
for (auto i : range(lhs.getNumAggregateElements()))
53675389
newElements.push_back(
5368-
accumulateAdjointsDirect(lhs.takeAggregateElement(i),
5369-
rhs.takeAggregateElement(i)));
5390+
accumulateAdjointsDirect(lhs.getAggregateElement(i),
5391+
rhs.getAggregateElement(i)));
53705392
return makeAggregateAdjointValue(lhs.getType(), newElements);
53715393
}
53725394
}

test/AutoDiff/leakchecking.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ LeakCheckingTests.test("ControlFlow") {
9292
// FIXME: Fix control flow AD memory leaks.
9393
// See related FIXME comments in adjoint value/buffer propagation in
9494
// lib/SILOptimizer/Mandatory/Differentiation.cpp.
95-
testWithLeakChecking(expectedLeakCount: 74) {
95+
testWithLeakChecking(expectedLeakCount: 41) {
9696
func cond_nestedtuple_var(_ x: Tracked<Float>) -> Tracked<Float> {
9797
// Convoluted function returning `x + x`.
9898
var y = (x + x, x - x)
@@ -116,7 +116,7 @@ LeakCheckingTests.test("ControlFlow") {
116116
// FIXME: Fix control flow AD memory leaks.
117117
// See related FIXME comments in adjoint value/buffer propagation in
118118
// lib/SILOptimizer/Mandatory/Differentiation.cpp.
119-
testWithLeakChecking(expectedLeakCount: 300) {
119+
testWithLeakChecking(expectedLeakCount: 193) {
120120
func cond_nestedstruct_var(_ x: Tracked<Float>) -> Tracked<Float> {
121121
// Convoluted function returning `x + x`.
122122
var y = FloatPair(x + x, x - x)
@@ -157,7 +157,7 @@ LeakCheckingTests.test("ControlFlow") {
157157
// FIXME: Fix control flow AD memory leaks.
158158
// See related FIXME comments in adjoint value/buffer propagation in
159159
// lib/SILOptimizer/Mandatory/Differentiation.cpp.
160-
testWithLeakChecking(expectedLeakCount: 6) {
160+
testWithLeakChecking(expectedLeakCount: 3) {
161161
var model = ExampleLeakModel()
162162
let x: Tracked<Float> = 1.0
163163
_ = model.gradient(at: x) { m, x in

0 commit comments

Comments
 (0)