Skip to content

Commit d799852

Browse files
authored
[AutoDiff] Fix a potential use-after-free and a potential memory leak. (#26345)
In e8a53ae, fix a use-after-free when a function being differentiated captures an address-only value. The fix is to create a copied buffer for address-only `partial_apply` arguments during `reapplyFunctionConversions`. In d14bb7f, fix a memory leak when a pullback returns a loadable value indirectly. In c95b6bd, remove `emitCleanup` and use `SILBuilder::emitDestroyAddrAndFold` and `SILBuilder::emitReleaseValueAndFold` for code clarity. Luckily, neither of the memory issues has occurred yet partly because we have not migrated to [ad-all-indirect](https://github.com/rxwei/swift/tree/ad-all-indirect) yet.
1 parent 781d8e3 commit d799852

File tree

2 files changed

+38
-29
lines changed

2 files changed

+38
-29
lines changed

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 28 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1916,14 +1916,6 @@ static SILValue joinElements(ArrayRef<SILValue> elements, SILBuilder &builder,
19161916
return builder.createTuple(loc, elements);
19171917
}
19181918

1919-
// Emits a release based on the value's type category (address or object).
1920-
static void emitCleanup(SILBuilder &builder, SILLocation loc, SILValue v) {
1921-
if (v->getType().isAddress())
1922-
builder.createDestroyAddr(loc, v);
1923-
else
1924-
builder.createReleaseValue(loc, v, builder.getDefaultAtomicity());
1925-
}
1926-
19271919
/// When a function value is used in an instruction (usually `apply`), there's
19281920
/// some conversion instruction in between, e.g. `thin_to_thick_function`. Given
19291921
/// a new function value and an old function value, this helper function
@@ -1953,14 +1945,26 @@ reapplyFunctionConversion(SILValue newFunc, SILValue oldFunc,
19531945
if (auto *pai = dyn_cast<PartialApplyInst>(oldConvertedFunc)) {
19541946
SmallVector<SILValue, 8> newArgs;
19551947
newArgs.reserve(pai->getNumArguments());
1948+
SmallVector<AllocStackInst *, 1> copiedIndirectParams;
1949+
SWIFT_DEFER {
1950+
for (auto *alloc : reversed(copiedIndirectParams))
1951+
builder.createDeallocStack(loc, alloc);
1952+
};
19561953
for (auto arg : pai->getArguments()) {
19571954
// Retain the argument since it's to be owned by the newly created
19581955
// closure.
1959-
if (arg->getType().isObject())
1956+
if (arg->getType().isObject()) {
19601957
builder.createRetainValue(loc, arg, builder.getDefaultAtomicity());
1961-
else if (arg->getType().isLoadable(builder.getFunction()))
1958+
newArgs.push_back(arg);
1959+
} else if (arg->getType().isLoadable(builder.getFunction())) {
19621960
builder.createRetainValueAddr(loc, arg, builder.getDefaultAtomicity());
1963-
newArgs.push_back(arg);
1961+
newArgs.push_back(arg);
1962+
} else {
1963+
auto *argCopy = builder.createAllocStack(loc, arg->getType());
1964+
copiedIndirectParams.push_back(argCopy);
1965+
builder.createCopyAddr(loc, arg, argCopy, IsNotTake, IsInitialization);
1966+
newArgs.push_back(argCopy);
1967+
}
19641968
}
19651969
auto innerNewFunc = reapplyFunctionConversion(
19661970
newFunc, oldFunc, pai->getCallee(), builder, loc, newFuncGenSig);
@@ -3395,8 +3399,7 @@ class VJPEmitter final
33953399

33963400
// Release the differentiable function.
33973401
if (differentiableFunc)
3398-
builder.createReleaseValue(loc, differentiableFunc,
3399-
builder.getDefaultAtomicity());
3402+
builder.emitReleaseValueAndFold(loc, differentiableFunc);
34003403

34013404
// Get the VJP results (original results and pullback).
34023405
SmallVector<SILValue, 8> vjpDirectResults;
@@ -3962,7 +3965,7 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
39623965
LLVM_DEBUG(getADDebugStream() << "Cleaning up temporaries for bb"
39633966
<< bb->getDebugID() << '\n');
39643967
for (auto temp : blockTemporaries[bb]) {
3965-
emitCleanup(builder, loc, temp);
3968+
builder.emitReleaseValueAndFold(loc, temp);
39663969
blockTemporarySet.erase(temp);
39673970
}
39683971
}
@@ -4534,9 +4537,6 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
45344537
auto adjBuf = getAdjointBuffer(origEntry, origParam);
45354538
if (errorOccurred)
45364539
return;
4537-
if (adjBuf->getType().isLoadable(getPullback()))
4538-
builder.createRetainValueAddr(pbLoc, adjBuf,
4539-
builder.getDefaultAtomicity());
45404540
indParamAdjoints.push_back(adjBuf);
45414541
}
45424542
};
@@ -4837,7 +4837,7 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
48374837
ApplyInst *ai, CopyAddrInst *cai, AllocStackInst *subscriptBuffer) {
48384838
addToAdjointBuffer(cai->getParent(), cai->getSrc(), subscriptBuffer,
48394839
cai->getLoc());
4840-
emitCleanup(builder, cai->getLoc(), subscriptBuffer);
4840+
builder.emitDestroyAddrAndFold(cai->getLoc(), subscriptBuffer);
48414841
builder.createDeallocStack(ai->getLoc(), subscriptBuffer);
48424842
}
48434843

@@ -5033,7 +5033,7 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
50335033
auto tan = *allResultsIt++;
50345034
if (tan->getType().isAddress()) {
50355035
addToAdjointBuffer(bb, origArg, tan, loc);
5036-
emitCleanup(builder, loc, tan);
5036+
builder.emitDestroyAddrAndFold(loc, tan);
50375037
} else {
50385038
if (origArg->getType().isAddress()) {
50395039
if (errorOccurred)
@@ -5042,7 +5042,7 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
50425042
builder.createStore(loc, tan, tmpBuf,
50435043
getBufferSOQ(tmpBuf->getType().getASTType(), getPullback()));
50445044
addToAdjointBuffer(bb, origArg, tmpBuf, loc);
5045-
emitCleanup(builder, loc, tmpBuf);
5045+
builder.emitDestroyAddrAndFold(loc, tmpBuf);
50465046
builder.createDeallocStack(loc, tmpBuf);
50475047
}
50485048
else {
@@ -5338,7 +5338,7 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
53385338
return;
53395339
auto destType = remapType(adjDest->getType());
53405340
addToAdjointBuffer(bb, cai->getSrc(), adjDest, cai->getLoc());
5341-
emitCleanup(builder, cai->getLoc(), adjDest);
5341+
builder.emitDestroyAddrAndFold(cai->getLoc(), adjDest);
53425342
emitZeroIndirect(destType.getASTType(), adjDest, cai->getLoc());
53435343
}
53445344

@@ -6297,13 +6297,15 @@ ADContext::getOrCreateSubsetParametersThunkForLinearMap(
62976297
// - Push it to `results` if result is direct.
62986298
auto result = allResults[mapOriginalParameterIndex(i)];
62996299
if (desiredIndices.isWrtParameter(i)) {
6300-
if (result->getType().isAddress())
6301-
continue;
6302-
results.push_back(result);
6300+
if (result->getType().isObject())
6301+
results.push_back(result);
63036302
}
63046303
// Otherwise, cleanup the unused results.
63056304
else {
6306-
emitCleanup(builder, loc, result);
6305+
if (result->getType().isAddress())
6306+
builder.emitDestroyAddrAndFold(loc, result);
6307+
else
6308+
builder.emitReleaseValueAndFold(loc, result);
63076309
}
63086310
}
63096311
// Deallocate local allocations and return final direct result.
@@ -6676,7 +6678,7 @@ void ADContext::foldAutoDiffFunctionExtraction(AutoDiffFunctionInst *source) {
66766678
if (isInstructionTriviallyDead(source)) {
66776679
SILBuilder builder(source);
66786680
for (auto &assocFn : source->getAssociatedFunctions())
6679-
emitCleanup(builder, source->getLoc(), assocFn.get());
6681+
builder.emitDestroyAddrAndFold(source->getLoc(), assocFn.get());
66806682
source->eraseFromParent();
66816683
}
66826684
// Mark `source` as processed so that it won't be reprocessed after deletion.

test/AutoDiff/witness_method_autodiff.sil

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,18 @@ bb0(%0 : $*T):
3838
}
3939

4040
// CHECK-LABEL: sil @differentiatePartiallyAppliedWitnessMethod
41+
// CHECK: bb0([[ARG:%.*]] : $*T):
4142
// CHECK: [[ORIG_REF:%.*]] = witness_method $T, #DiffReq.f!1
4243
// CHECK: [[ORIG_REF_PARTIALLY_APPLIED:%.*]] = partial_apply [callee_guaranteed] [[ORIG_REF]]<T>(%0)
4344
// CHECK: [[JVP_REF:%.*]] = witness_method $T, #DiffReq.f!1.jvp.1.SU
44-
// CHECK: [[JVP_REF_PARTIALLY_APPLIED:%.*]] = partial_apply [callee_guaranteed] [[JVP_REF]]<T>(%0)
45+
// CHECK: [[ARGCOPY1:%.*]] = alloc_stack $T
46+
// CHECK: copy_addr [[ARG]] to [initialization] [[ARGCOPY1]] : $*T
47+
// CHECK: [[JVP_REF_PARTIALLY_APPLIED:%.*]] = partial_apply [callee_guaranteed] [[JVP_REF]]<T>([[ARGCOPY1]])
48+
// CHECK: dealloc_stack [[ARGCOPY1]]
4549
// CHECK: [[VJP_REF:%.*]] = witness_method $T, #DiffReq.f!1.vjp.1.SU
46-
// CHECK: [[VJP_REF_PARTIALLY_APPLIED:%.*]] = partial_apply [callee_guaranteed] [[VJP_REF]]<T>(%0)
47-
// CHECK: = autodiff_function [wrt 0] [order 1] [[ORIG_REF_PARTIALLY_APPLIED]] : {{.*}} with {[[JVP_REF_PARTIALLY_APPLIED]] : {{.*}}, [[VJP_REF_PARTIALLY_APPLIED]] : {{.*}}}
50+
// CHECK: [[ARGCOPY2:%.*]] = alloc_stack $T
51+
// CHECK: copy_addr [[ARG]] to [initialization] [[ARGCOPY2]] : $*T
52+
// CHECK: [[VJP_REF_PARTIALLY_APPLIED:%.*]] = partial_apply [callee_guaranteed] [[VJP_REF]]<T>([[ARGCOPY2]])
53+
// CHECK: dealloc_stack [[ARGCOPY2]]
54+
// CHECK: autodiff_function [wrt 0] [order 1] [[ORIG_REF_PARTIALLY_APPLIED]] : {{.*}} with {[[JVP_REF_PARTIALLY_APPLIED]] : {{.*}}, [[VJP_REF_PARTIALLY_APPLIED]] : {{.*}}}
4855
// CHECK: } // end sil function 'differentiatePartiallyAppliedWitnessMethod'

0 commit comments

Comments
 (0)