Skip to content

[AutoDiff] Fix use after free when pullback is used multiple times #64647

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 6 additions & 9 deletions include/swift/SILOptimizer/Differentiation/LinearMapInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,15 +86,16 @@ class LinearMapInfo {
llvm::DenseMap<std::pair<SILBasicBlock *, SILBasicBlock *>, EnumElementDecl *>
branchingTraceEnumCases;

/// Blocks in a loop.
llvm::SmallSetVector<SILBasicBlock *, 4> blocksInLoop;

/// A synthesized file unit.
SynthesizedFileUnit &synthesizedFile;

/// A type converter, used to compute struct/enum SIL types.
Lowering::TypeConverter &typeConverter;

/// True if a heap-allocated context is required, for example when there are
/// any loops.
bool heapAllocatedContext = false;

private:
/// Remaps the given type into the derivative function's context.
SILType remapTypeInDerivative(SILType ty);
Expand Down Expand Up @@ -193,12 +194,8 @@ class LinearMapInfo {
return getLinearMapTupleType(ai->getParentBlock())->getElement(idx).getType();
}

bool hasLoops() const {
return !blocksInLoop.empty();
}

ArrayRef<SILBasicBlock *> getBlocksInLoop() const {
return blocksInLoop.getArrayRef();
bool hasHeapAllocatedContext() const {
return heapAllocatedContext;
}
};

Expand Down
6 changes: 2 additions & 4 deletions lib/SILOptimizer/Differentiation/LinearMapInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,11 +138,9 @@ void LinearMapInfo::populateBranchingTraceDecl(SILBasicBlock *originalBB,
// indirectly referenced in memory owned by the context object. The payload
// is just a raw pointer.
if (loopInfo->getLoopFor(predBB)) {
blocksInLoop.insert(predBB);
heapAllocatedContext = true;
decl->setInterfaceType(astCtx.TheRawPointerType);
}
// Otherwise the payload is the linear map tuple.
else {
} else { // Otherwise the payload is the linear map tuple.
auto linearMapStructTy = getLinearMapTupleType(predBB)->getCanonicalType();
decl->setInterfaceType(
linearMapStructTy->hasArchetype()
Expand Down
8 changes: 4 additions & 4 deletions lib/SILOptimizer/Differentiation/PullbackCloner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1928,7 +1928,7 @@ bool PullbackCloner::Implementation::run() {
builder.setInsertionPoint(pullbackBB);
// Obtain the context object, if any, and the top-level subcontext, i.e.
// the main pullback struct.
if (getPullbackInfo().hasLoops()) {
if (getPullbackInfo().hasHeapAllocatedContext()) {
// The last argument is the context object (`Builtin.NativeObject`).
contextValue = pullbackBB->getArguments().back();
assert(contextValue->getType() ==
Expand All @@ -1939,7 +1939,7 @@ bool PullbackCloner::Implementation::run() {
SILValue mainPullbackTuple = builder.createLoad(
pbLoc, subcontextAddr,
pbTupleLoweredType.isTrivial(getPullback()) ?
LoadOwnershipQualifier::Trivial : LoadOwnershipQualifier::Take);
LoadOwnershipQualifier::Trivial : LoadOwnershipQualifier::Copy);
auto *dsi = builder.createDestructureTuple(pbLoc, mainPullbackTuple);
initializePullbackTupleElements(origBB, dsi->getAllResults());
} else {
Expand Down Expand Up @@ -2023,7 +2023,7 @@ bool PullbackCloner::Implementation::run() {
auto *pullbackEntry = pullback.getEntryBlock();
auto pbTupleLoweredType =
remapType(getPullbackInfo().getLinearMapTupleLoweredType(originalExitBlock));
unsigned numVals = (getPullbackInfo().hasLoops() ?
unsigned numVals = (getPullbackInfo().hasHeapAllocatedContext() ?
1 : pbTupleLoweredType.getAs<TupleType>()->getNumElements());
(void)numVals;

Expand Down Expand Up @@ -2449,7 +2449,7 @@ SILBasicBlock *PullbackCloner::Implementation::buildPullbackSuccessor(
auto predPbStructVal = pullbackTrampolineBBBuilder.createLoad(
loc, predPbTupleAddr,
pbTupleType.isTrivial(getPullback()) ?
LoadOwnershipQualifier::Trivial : LoadOwnershipQualifier::Take);
LoadOwnershipQualifier::Trivial : LoadOwnershipQualifier::Copy);
trampolineArguments.push_back(predPbStructVal);
} else {
trampolineArguments.push_back(pullbackTrampolineBBArg);
Expand Down
4 changes: 2 additions & 2 deletions lib/SILOptimizer/Differentiation/VJPCloner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ class VJPCloner::Implementation final

/// Initializes a context object if needed.
void emitLinearMapContextInitializationIfNeeded() {
if (!pullbackInfo.hasLoops())
if (!pullbackInfo.hasHeapAllocatedContext())
return;

// Get linear map struct size.
Expand Down Expand Up @@ -937,7 +937,7 @@ SILFunction *VJPCloner::Implementation::createEmptyPullback() {
pbParams.push_back(inoutParamTanParam);
}

if (pullbackInfo.hasLoops()) {
if (pullbackInfo.hasHeapAllocatedContext()) {
// Accept an `AutoDiffLinearMapContext` heap object if there are loops.
pbParams.push_back({
getASTContext().TheNativeObjectType,
Expand Down
50 changes: 50 additions & 0 deletions test/AutoDiff/validation-test/issue-64257-use-after-free.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// RUN: %target-run-simple-swift
// REQUIRES: executable_test

import _Differentiation
/// Differentiates `run` at `params` and hands back both the primal value and
/// the pullback closure, so that the caller can invoke the pullback itself.
///
/// - Parameter params: The point at which `run` is evaluated and differentiated.
/// - Returns: The result of `run(params:)` together with its pullback.
public func simWithPullback(
  params: MinParams
) -> (value: Output, pullback: (Output.TangentVector) -> (MinParams.TangentVector)) {
  let (primal, backpropagate) = valueWithPullback(at: params, of: run)
  return (value: primal, pullback: backpropagate)
}

/// Differentiable function under test: builds a `MiniLoop` from
/// `params._twoDArray` and wraps its `results` field in an `Output`.
///
/// - Parameter params: Supplies the value passed as `MiniLoop`'s `other` argument.
/// - Returns: An `Output` holding the `results` of the constructed `MiniLoop`.
@differentiable(reverse)
public func run(params: MinParams) -> Output {
// NOTE(review): this empty loop appears to be intentional. Presumably it gives
// `run` loop control flow so that its pullback uses the heap-allocated
// linear-map context that this regression test targets. Confirm before
// removing it or renaming the unused `t`.
for t in 0 ... 1 {
}

// Only `other` is passed, so `results` keeps the initializer's default value.
let res = MiniLoop(other: params._twoDArray).results
return Output(results: res)
}

/// Small differentiable aggregate with two `Float` fields, created inside
/// `run` through its differentiable initializer.
struct MiniLoop: Differentiable {
  var results: Float
  var twoDArray: Float

  /// Stores `results` as-is and `other` into `twoDArray`.
  ///
  /// - Parameters:
  ///   - results: Value for the `results` field (defaults to 146).
  ///   - other: Value for the `twoDArray` field (defaults to 123).
  @differentiable(reverse)
  init(results: Float = 146, other: Float = 123) {
    self.results = results
    self.twoDArray = other
  }
}

/// Differentiable result type returned by `run`; wraps a single `Float`.
public struct Output: Differentiable {
  public var results: Float

  /// Creates an output holding `results`.
  @differentiable(reverse)
  public init(results: Float) {
    self.results = results
  }
}

/// Differentiable parameter bundle consumed by `run`; wraps a single `Float`.
public struct MinParams: Differentiable {
  public var _twoDArray: Float

  /// Creates parameters whose `_twoDArray` is `foo` (defaults to 42).
  public init(foo: Float = 42) {
    self._twoDArray = foo
  }
}

/// Regression driver for issue #64257 (use after free when a pullback is
/// invoked more than once).
///
/// Obtains the pullback of `run` once and calls it three times with the same
/// seed. Each gradient is printed, as before; in addition, every repeated call
/// must produce the same gradient as the first one, so a pullback that
/// corrupts its captured state on first use is detected even when it does not
/// crash.
func test() {
  let valueAndPullback = simWithPullback(params: MinParams())
  let seed = Output.TangentVector(results: Float(1.0))

  let firstGradient = valueAndPullback.pullback(seed)
  print(firstGradient)

  // Two further invocations of the *same* pullback closure; this reuse is
  // what the regression test exercises.
  for _ in 1 ..< 3 {
    let gradient = valueAndPullback.pullback(seed)
    print(gradient)
    precondition(
      gradient == firstGradient,
      "pullback returned a different gradient on a repeated call")
  }
}

test()