Skip to content

Commit 8990a12

Browse files
authored
Fix use after free when pullback is used multiple times. (#64647)
Linear maps are captured in vjp routine via callee-guaranteed partial apply and are passed as @owned references to the enclosing pullback that finally consumes them. Necessary retains are inserted by a partial apply forwarder. However, this is not the case when the function being differentiated contains loops as heap-allocated context is used and bare pointer is captured by the pullback partial apply. As a result, partial apply forwarder does not retain the linear maps that are owned by a heap-allocated context, however, they are still treated as @owned references and therefore are released in the pullback after the first call. As a result, subsequent pullback calls release linear maps and we'd end with possible use-after-free. Ensure we retain values when we load values from the context. Reproducible only when: * Loops (so, heap-allocated context) * Pullbacks of thick functions (so context is non-zero) * Multiple pullback calls * Some cleanup while there Fixes #64257
1 parent 8ac71b1 commit 8990a12

File tree

5 files changed

+64
-19
lines changed

5 files changed

+64
-19
lines changed

include/swift/SILOptimizer/Differentiation/LinearMapInfo.h

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -86,15 +86,16 @@ class LinearMapInfo {
8686
llvm::DenseMap<std::pair<SILBasicBlock *, SILBasicBlock *>, EnumElementDecl *>
8787
branchingTraceEnumCases;
8888

89-
/// Blocks in a loop.
90-
llvm::SmallSetVector<SILBasicBlock *, 4> blocksInLoop;
91-
9289
/// A synthesized file unit.
9390
SynthesizedFileUnit &synthesizedFile;
9491

9592
/// A type converter, used to compute struct/enum SIL types.
9693
Lowering::TypeConverter &typeConverter;
9794

95+
/// True, if a heap-allocated context is required. For example, when there are
96+
/// any loops
97+
bool heapAllocatedContext = false;
98+
9899
private:
99100
/// Remaps the given type into the derivative function's context.
100101
SILType remapTypeInDerivative(SILType ty);
@@ -193,12 +194,8 @@ class LinearMapInfo {
193194
return getLinearMapTupleType(ai->getParentBlock())->getElement(idx).getType();
194195
}
195196

196-
bool hasLoops() const {
197-
return !blocksInLoop.empty();
198-
}
199-
200-
ArrayRef<SILBasicBlock *> getBlocksInLoop() const {
201-
return blocksInLoop.getArrayRef();
197+
bool hasHeapAllocatedContext() const {
198+
return heapAllocatedContext;
202199
}
203200
};
204201

lib/SILOptimizer/Differentiation/LinearMapInfo.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -138,11 +138,9 @@ void LinearMapInfo::populateBranchingTraceDecl(SILBasicBlock *originalBB,
138138
// indirectly referenced in memory owned by the context object. The payload
139139
// is just a raw pointer.
140140
if (loopInfo->getLoopFor(predBB)) {
141-
blocksInLoop.insert(predBB);
141+
heapAllocatedContext = true;
142142
decl->setInterfaceType(astCtx.TheRawPointerType);
143-
}
144-
// Otherwise the payload is the linear map tuple.
145-
else {
143+
} else { // Otherwise the payload is the linear map tuple.
146144
auto linearMapStructTy = getLinearMapTupleType(predBB)->getCanonicalType();
147145
decl->setInterfaceType(
148146
linearMapStructTy->hasArchetype()

lib/SILOptimizer/Differentiation/PullbackCloner.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1928,7 +1928,7 @@ bool PullbackCloner::Implementation::run() {
19281928
builder.setInsertionPoint(pullbackBB);
19291929
// Obtain the context object, if any, and the top-level subcontext, i.e.
19301930
// the main pullback struct.
1931-
if (getPullbackInfo().hasLoops()) {
1931+
if (getPullbackInfo().hasHeapAllocatedContext()) {
19321932
// The last argument is the context object (`Builtin.NativeObject`).
19331933
contextValue = pullbackBB->getArguments().back();
19341934
assert(contextValue->getType() ==
@@ -1939,7 +1939,7 @@ bool PullbackCloner::Implementation::run() {
19391939
SILValue mainPullbackTuple = builder.createLoad(
19401940
pbLoc, subcontextAddr,
19411941
pbTupleLoweredType.isTrivial(getPullback()) ?
1942-
LoadOwnershipQualifier::Trivial : LoadOwnershipQualifier::Take);
1942+
LoadOwnershipQualifier::Trivial : LoadOwnershipQualifier::Copy);
19431943
auto *dsi = builder.createDestructureTuple(pbLoc, mainPullbackTuple);
19441944
initializePullbackTupleElements(origBB, dsi->getAllResults());
19451945
} else {
@@ -2023,7 +2023,7 @@ bool PullbackCloner::Implementation::run() {
20232023
auto *pullbackEntry = pullback.getEntryBlock();
20242024
auto pbTupleLoweredType =
20252025
remapType(getPullbackInfo().getLinearMapTupleLoweredType(originalExitBlock));
2026-
unsigned numVals = (getPullbackInfo().hasLoops() ?
2026+
unsigned numVals = (getPullbackInfo().hasHeapAllocatedContext() ?
20272027
1 : pbTupleLoweredType.getAs<TupleType>()->getNumElements());
20282028
(void)numVals;
20292029

@@ -2449,7 +2449,7 @@ SILBasicBlock *PullbackCloner::Implementation::buildPullbackSuccessor(
24492449
auto predPbStructVal = pullbackTrampolineBBBuilder.createLoad(
24502450
loc, predPbTupleAddr,
24512451
pbTupleType.isTrivial(getPullback()) ?
2452-
LoadOwnershipQualifier::Trivial : LoadOwnershipQualifier::Take);
2452+
LoadOwnershipQualifier::Trivial : LoadOwnershipQualifier::Copy);
24532453
trampolineArguments.push_back(predPbStructVal);
24542454
} else {
24552455
trampolineArguments.push_back(pullbackTrampolineBBArg);

lib/SILOptimizer/Differentiation/VJPCloner.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ class VJPCloner::Implementation final
110110

111111
/// Initializes a context object if needed.
112112
void emitLinearMapContextInitializationIfNeeded() {
113-
if (!pullbackInfo.hasLoops())
113+
if (!pullbackInfo.hasHeapAllocatedContext())
114114
return;
115115

116116
// Get linear map struct size.
@@ -937,7 +937,7 @@ SILFunction *VJPCloner::Implementation::createEmptyPullback() {
937937
pbParams.push_back(inoutParamTanParam);
938938
}
939939

940-
if (pullbackInfo.hasLoops()) {
940+
if (pullbackInfo.hasHeapAllocatedContext()) {
941941
// Accept a `AutoDiffLinarMapContext` heap object if there are loops.
942942
pbParams.push_back({
943943
getASTContext().TheNativeObjectType,
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
// RUN: %target-run-simple-swift
2+
// REQUIRES: executable_test
3+
4+
import _Differentiation
5+
public func simWithPullback(params: MinParams) -> (value: Output, pullback: (Output.TangentVector) -> (MinParams.TangentVector)){
6+
let simulationValueAndPullback = valueWithPullback(at: params, of: run)
7+
return (value: simulationValueAndPullback.value, pullback: simulationValueAndPullback.pullback)
8+
}
9+
10+
@differentiable(reverse)
11+
public func run(params: MinParams) -> Output {
12+
for t in 0 ... 1 {
13+
}
14+
15+
let res = MiniLoop(other: params._twoDArray).results
16+
return Output(results: res)
17+
}
18+
19+
struct MiniLoop: Differentiable {
20+
var results: Float
21+
var twoDArray: Float
22+
23+
@differentiable(reverse)
24+
init(results: Float = 146, other: Float = 123) {self.results = results; self.twoDArray = other}
25+
}
26+
27+
public struct Output: Differentiable {
28+
public var results: Float
29+
@differentiable(reverse)
30+
public init(results: Float) {self.results = results}
31+
}
32+
33+
public struct MinParams: Differentiable {
34+
public var _twoDArray: Float
35+
public init(foo: Float = 42) { _twoDArray = foo }
36+
}
37+
38+
func test() {
39+
let valueAndPullback = simWithPullback(params: MinParams())
40+
let output = valueAndPullback.value
41+
let resultOnes = Float(1.0)
42+
var grad = valueAndPullback.pullback(Output.TangentVector(results: resultOnes))
43+
print(grad)
44+
grad = valueAndPullback.pullback(Output.TangentVector(results: resultOnes))
45+
print(grad)
46+
grad = valueAndPullback.pullback(Output.TangentVector(results: resultOnes))
47+
print(grad)
48+
}
49+
50+
test()

0 commit comments

Comments
 (0)