Skip to content

[AutoDiff] Another partial fix for adjoint memory leaks. #25300

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 8, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 67 additions & 45 deletions lib/SILOptimizer/Mandatory/Differentiation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3519,15 +3519,15 @@ class AdjointValueBase {

/// The underlying value.
union Value {
MutableArrayRef<AdjointValue> aggregate;
ArrayRef<AdjointValue> aggregate;
ValueWithCleanup concrete;
Value(MutableArrayRef<AdjointValue> v) : aggregate(v) {}
Value(ArrayRef<AdjointValue> v) : aggregate(v) {}
Value(ValueWithCleanup v) : concrete(v) {}
Value() {}
} value;

explicit AdjointValueBase(SILType type,
MutableArrayRef<AdjointValue> aggregate)
ArrayRef<AdjointValue> aggregate)
: kind(AdjointValueKind::Aggregate), type(type), value(aggregate) {}

explicit AdjointValueBase(ValueWithCleanup v)
Expand Down Expand Up @@ -3591,11 +3591,15 @@ class AdjointValue final {
return base->value.aggregate.size();
}

AdjointValue takeAggregateElement(unsigned i) {
AdjointValue getAggregateElement(unsigned i) const {
assert(isAggregate());
return base->value.aggregate[i];
}

/// Returns a read-only view of all elements of this aggregate adjoint value.
/// Must only be called when this value is an aggregate: the underlying
/// storage is a union, so reading `aggregate` for any other kind is invalid.
ArrayRef<AdjointValue> getAggregateElements() const {
  assert(isAggregate());
  return base->value.aggregate;
}

ValueWithCleanup getConcreteValue() const {
assert(isConcrete());
return base->value.concrete;
Expand Down Expand Up @@ -3684,6 +3688,11 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
/// Mapping from original basic blocks to dominated active values.
DenseMap<SILBasicBlock *, SmallVector<SILValue, 8>> activeValues;

/// Local adjoint values to be cleaned up. This is populated when adjoint
/// emission is run on one basic block and cleaned before processing another
/// basic block.
SmallVector<AdjointValue, 8> blockLocalAdjointValues;

/// Mapping from original basic blocks and original active values to
/// corresponding adjoint block arguments.
DenseMap<std::pair<SILBasicBlock *, SILValue>, SILArgument *>
Expand Down Expand Up @@ -3758,7 +3767,7 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
AdjointValue makeAggregateAdjointValue(SILType type, EltRange elements);

//--------------------------------------------------------------------------//
// Managed value materializers
// Symbolic value materializers
//--------------------------------------------------------------------------//

/// Materialize an adjoint value. The type of the given adjoint value must be
Expand All @@ -3777,7 +3786,7 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
AdjointValue val, ValueWithCleanup &destBufferAccess);

//--------------------------------------------------------------------------//
// Helpers for managed value materializers
// Helpers for symbolic value materializers
//--------------------------------------------------------------------------//

/// Emit a zero value by calling `AdditiveArithmetic.zero`. The given type
Expand All @@ -3788,6 +3797,12 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
/// must conform to `AdditiveArithmetic` and be loadable in SIL.
SILValue emitZeroDirect(CanType type, SILLocation loc);

//--------------------------------------------------------------------------//
// Memory cleanup tools
//--------------------------------------------------------------------------//

void emitCleanupForAdjointValue(AdjointValue value);

//--------------------------------------------------------------------------//
// Accumulator
//--------------------------------------------------------------------------//
Expand Down Expand Up @@ -3899,8 +3914,9 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
auto it = insertion.first;
auto &&existingValue = it->getSecond();
valueMap.erase(it);
initializeAdjointValue(origBB, originalValue,
accumulateAdjointsDirect(existingValue, newAdjointValue));
auto adjVal = accumulateAdjointsDirect(existingValue, newAdjointValue);
initializeAdjointValue(origBB, originalValue, adjVal);
blockLocalAdjointValues.push_back(adjVal);
}

/// Get the adjoint block argument corresponding to the given original block
Expand Down Expand Up @@ -4378,7 +4394,7 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
// Emit cleanups for children.
if (auto *cleanup = concreteActiveValueAdj.getCleanup()) {
cleanup->disable();
cleanup->applyRecursively(builder, activeValue.getLoc());
cleanup->applyRecursively(builder, adjLoc);
}
trampolineArguments.push_back(concreteActiveValueAdj);
// If the adjoint block does not yet have a registered adjoint
Expand Down Expand Up @@ -4415,11 +4431,15 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
getPullbackInfo().lookUpPredecessorEnumElement(predBB, bb);
adjointSuccessorCases.push_back({enumEltDecl, adjointSuccBB});
}
assert(adjointSuccessorCases.size() == predEnum->getNumElements());
// Emit cleanups for all block-local adjoint values.
for (auto adjVal : blockLocalAdjointValues)
emitCleanupForAdjointValue(adjVal);
blockLocalAdjointValues.clear();
// - If the original block has exactly one predecessor, then the adjoint
// block has exactly one successor. Extract the pullback struct value
// from the predecessor enum value using `unchecked_enum_data` and
// branch to the adjoint successor block.
assert(adjointSuccessorCases.size() == predEnum->getNumElements());
if (adjointSuccessorCases.size() == 1) {
auto *predBB = bb->getSinglePredecessorBlock();
assert(predBB);
Expand Down Expand Up @@ -4479,9 +4499,13 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
indParamAdjoints.push_back(adjBuf);
}
};
// Accumulate differentiation parameter adjoints.
// Collect differentiation parameter adjoints.
for (auto i : getIndices().parameters->getIndices())
addRetElt(i);
// Emit cleanups for all local values.
for (auto adjVal : blockLocalAdjointValues)
emitCleanupForAdjointValue(adjVal);
blockLocalAdjointValues.clear();

// Disable cleanup for original indirect parameter adjoint buffers.
// Copy them to adjoint indirect results.
Expand Down Expand Up @@ -4667,8 +4691,9 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
builder.createDeallocStack(loc, tmpBuf);
}
else
addAdjointValue(bb, origArg, makeConcreteAdjointValue(ValueWithCleanup(
tan, makeCleanup(tan, emitCleanup, {seed.getCleanup()}))));
addAdjointValue(bb, origArg, makeConcreteAdjointValue(
ValueWithCleanup(tan,
makeCleanup(tan, emitCleanup, {seed.getCleanup()}))));
}
}
// Deallocate pullback indirect results.
Expand Down Expand Up @@ -4857,7 +4882,7 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
for (auto i : range(ti->getElements().size())) {
if (!getTangentSpace(ti->getElement(i)->getType().getASTType()))
continue;
addAdjointValue(bb, ti->getElement(i), av.takeAggregateElement(adjIdx++));
addAdjointValue(bb, ti->getElement(i), av.getAggregateElement(adjIdx++));
}
break;
}
Expand Down Expand Up @@ -4964,18 +4989,7 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
addAdjointValue(bb, si->getSrc(), makeConcreteAdjointValue(
ValueWithCleanup(adjVal, valueCleanup)));
// Set the buffer to zero, with a cleanup.
auto *bai = dyn_cast<BeginAccessInst>(adjBuf.getValue());
if (bai && !(bai->getAccessKind() == SILAccessKind::Modify ||
bai->getAccessKind() == SILAccessKind::Init)) {
auto *modifyAccess = builder.createBeginAccess(
si->getLoc(), bai->getSource(), SILAccessKind::Modify,
SILAccessEnforcement::Static, /*noNestedConflict*/ true,
/*fromBuiltin*/ false);
emitZeroIndirect(bufType.getASTType(), modifyAccess, si->getLoc());
builder.createEndAccess(si->getLoc(), modifyAccess, /*aborted*/ false);
} else {
emitZeroIndirect(bufType.getASTType(), adjBuf, si->getLoc());
}
emitZeroIndirect(bufType.getASTType(), adjBuf, si->getLoc());
}

// Handle `copy_addr` instruction.
Expand All @@ -5001,18 +5015,7 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
addToAdjointBuffer(bb, cai->getSrc(), readAccess);
builder.createEndAccess(cai->getLoc(), readAccess, /*aborted*/ false);
// Set the buffer to zero, with a cleanup.
auto *bai = dyn_cast<BeginAccessInst>(adjDest.getValue());
if (bai && !(bai->getAccessKind() == SILAccessKind::Modify ||
bai->getAccessKind() == SILAccessKind::Init)) {
auto *modifyAccess = builder.createBeginAccess(
cai->getLoc(), bai->getSource(), SILAccessKind::Modify,
SILAccessEnforcement::Static, /*noNestedConflict*/ true,
/*fromBuiltin*/ false);
emitZeroIndirect(destType.getASTType(), modifyAccess, cai->getLoc());
builder.createEndAccess(cai->getLoc(), modifyAccess, /*aborted*/ false);
} else {
emitZeroIndirect(destType.getASTType(), adjDest, cai->getLoc());
}
emitZeroIndirect(destType.getASTType(), adjDest, cai->getLoc());
auto cleanup = makeCleanup(adjDest, emitCleanup);
adjDest.setCleanup(cleanup);
}
Expand Down Expand Up @@ -5140,7 +5143,7 @@ ValueWithCleanup AdjointEmitter::materializeAdjointDirect(
SmallVector<SILValue, 8> elements;
SmallVector<Cleanup *, 8> cleanups;
for (auto i : range(val.getNumAggregateElements())) {
auto eltVal = materializeAdjointDirect(val.takeAggregateElement(i), loc);
auto eltVal = materializeAdjointDirect(val.getAggregateElement(i), loc);
elements.push_back(eltVal.getValue());
cleanups.push_back(eltVal.getCleanup());
}
Expand Down Expand Up @@ -5216,7 +5219,7 @@ void AdjointEmitter::materializeAdjointIndirectHelper(
ValueWithCleanup eltBuf(
builder.createTupleElementAddr(loc, destBufferAccess, idx, eltTy),
/*cleanup*/ nullptr);
materializeAdjointIndirectHelper(val.takeAggregateElement(idx), eltBuf);
materializeAdjointIndirectHelper(val.getAggregateElement(idx), eltBuf);
destBufferAccess.setCleanup(makeCleanupFromChildren(
{destBufferAccess.getCleanup(), eltBuf.getCleanup()}));
}
Expand All @@ -5228,7 +5231,7 @@ void AdjointEmitter::materializeAdjointIndirectHelper(
ValueWithCleanup eltBuf(
builder.createStructElementAddr(loc, destBufferAccess, *fieldIt),
/*cleanup*/ nullptr);
materializeAdjointIndirectHelper(val.takeAggregateElement(i), eltBuf);
materializeAdjointIndirectHelper(val.getAggregateElement(i), eltBuf);
destBufferAccess.setCleanup(makeCleanupFromChildren(
{destBufferAccess.getCleanup(), eltBuf.getCleanup()}));
}
Expand Down Expand Up @@ -5293,6 +5296,25 @@ SILValue AdjointEmitter::emitZeroDirect(CanType type, SILLocation loc) {
return loaded;
}

/// Emit the cleanups attached to the given adjoint value.
///
/// - Zero values carry nothing to clean up.
/// - Aggregate values are handled by recursing into each element.
/// - Concrete values have their cleanup applied recursively via `builder`, at
///   the location of the concrete value.
void AdjointEmitter::emitCleanupForAdjointValue(AdjointValue value) {
  switch (value.getKind()) {
  case AdjointValueKind::Zero:
    return;
  case AdjointValueKind::Aggregate:
    for (auto element : value.getAggregateElements())
      emitCleanupForAdjointValue(element);
    break;
  case AdjointValueKind::Concrete: {
    auto concrete = value.getConcreteValue();
    auto *cleanup = concrete.getCleanup();
    // A concrete value may have been created without a cleanup (null); there
    // is nothing to emit in that case, and dereferencing it would crash.
    if (!cleanup)
      break;
    LLVM_DEBUG(getADDebugStream() << "Applying "
               << cleanup->getNumChildren() << " child cleanups for value "
               << concrete.getValue() << '\n');
    cleanup->applyRecursively(builder, concrete.getLoc());
    break;
  }
  }
}

AdjointValue
AdjointEmitter::accumulateAdjointsDirect(AdjointValue lhs,
AdjointValue rhs) {
Expand Down Expand Up @@ -5324,7 +5346,7 @@ AdjointEmitter::accumulateAdjointsDirect(AdjointValue lhs,
for (auto idx : range(rhs.getNumAggregateElements())) {
auto lhsElt = builder.createTupleExtract(
lhsVal.getLoc(), lhsVal, idx);
auto rhsElt = rhs.takeAggregateElement(idx);
auto rhsElt = rhs.getAggregateElement(idx);
newElements.push_back(accumulateAdjointsDirect(
makeConcreteAdjointValue(
ValueWithCleanup(lhsElt, lhsVal.getCleanup())),
Expand All @@ -5336,7 +5358,7 @@ AdjointEmitter::accumulateAdjointsDirect(AdjointValue lhs,
++fieldIt, ++i) {
auto lhsElt = builder.createStructExtract(
lhsVal.getLoc(), lhsVal, *fieldIt);
auto rhsElt = rhs.takeAggregateElement(i);
auto rhsElt = rhs.getAggregateElement(i);
newElements.push_back(accumulateAdjointsDirect(
makeConcreteAdjointValue(
ValueWithCleanup(lhsElt, lhsVal.getCleanup())),
Expand Down Expand Up @@ -5365,8 +5387,8 @@ AdjointEmitter::accumulateAdjointsDirect(AdjointValue lhs,
SmallVector<AdjointValue, 8> newElements;
for (auto i : range(lhs.getNumAggregateElements()))
newElements.push_back(
accumulateAdjointsDirect(lhs.takeAggregateElement(i),
rhs.takeAggregateElement(i)));
accumulateAdjointsDirect(lhs.getAggregateElement(i),
rhs.getAggregateElement(i)));
return makeAggregateAdjointValue(lhs.getType(), newElements);
}
}
Expand Down
6 changes: 3 additions & 3 deletions test/AutoDiff/leakchecking.swift
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ LeakCheckingTests.test("ControlFlow") {
// FIXME: Fix control flow AD memory leaks.
// See related FIXME comments in adjoint value/buffer propagation in
// lib/SILOptimizer/Mandatory/Differentiation.cpp.
testWithLeakChecking(expectedLeakCount: 74) {
testWithLeakChecking(expectedLeakCount: 41) {
func cond_nestedtuple_var(_ x: Tracked<Float>) -> Tracked<Float> {
// Convoluted function returning `x + x`.
var y = (x + x, x - x)
Expand All @@ -116,7 +116,7 @@ LeakCheckingTests.test("ControlFlow") {
// FIXME: Fix control flow AD memory leaks.
// See related FIXME comments in adjoint value/buffer propagation in
// lib/SILOptimizer/Mandatory/Differentiation.cpp.
testWithLeakChecking(expectedLeakCount: 300) {
testWithLeakChecking(expectedLeakCount: 193) {
func cond_nestedstruct_var(_ x: Tracked<Float>) -> Tracked<Float> {
// Convoluted function returning `x + x`.
var y = FloatPair(x + x, x - x)
Expand Down Expand Up @@ -157,7 +157,7 @@ LeakCheckingTests.test("ControlFlow") {
// FIXME: Fix control flow AD memory leaks.
// See related FIXME comments in adjoint value/buffer propagation in
// lib/SILOptimizer/Mandatory/Differentiation.cpp.
testWithLeakChecking(expectedLeakCount: 6) {
testWithLeakChecking(expectedLeakCount: 3) {
var model = ExampleLeakModel()
let x: Tracked<Float> = 1.0
_ = model.gradient(at: x) { m, x in
Expand Down