Skip to content

[AutoDiff] Generate transparent ossa reabstraction thunks. #33854

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 99 additions & 25 deletions lib/SILOptimizer/Differentiation/Thunk.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,58 @@ CanSILFunctionType buildThunkType(SILFunction *fn,
fn->getASTContext());
}

/// Forward function arguments, handling ownership convention mismatches.
/// Adapted from `forwardFunctionArguments` in SILGenPoly.cpp.
///
/// Forwarded arguments are appended to `forwardedArgs`.
///
/// Local allocations are appended to `localAllocations`. They need to be
/// deallocated via `dealloc_stack`.
///
/// Local values requiring cleanup are appended to `valuesToCleanup`.
static void forwardFunctionArgumentsConvertingOwnership(
SILBuilder &builder, SILLocation loc, CanSILFunctionType fromTy,
CanSILFunctionType toTy, ArrayRef<SILArgument *> originalArgs,
SmallVectorImpl<SILValue> &forwardedArgs,
SmallVectorImpl<AllocStackInst *> &localAllocations,
SmallVectorImpl<SILValue> &valuesToCleanup) {
auto fromParameters = fromTy->getParameters();
auto toParameters = toTy->getParameters();
assert(fromParameters.size() == toParameters.size());
assert(fromParameters.size() == originalArgs.size());
for (auto index : indices(originalArgs)) {
auto &arg = originalArgs[index];
auto fromParam = fromParameters[index];
auto toParam = toParameters[index];
// To convert guaranteed argument to be owned, create a copy.
if (fromParam.isConsumed() && !toParam.isConsumed()) {
// If the argument has an object type, create a `copy_value`.
if (arg->getType().isObject()) {
auto argCopy = builder.emitCopyValueOperation(loc, arg);
forwardedArgs.push_back(argCopy);
continue;
}
// If the argument has an address type, create a local allocation and
// `copy_addr` its contents to the local allocation.
auto *alloc = builder.createAllocStack(loc, arg->getType());
builder.createCopyAddr(loc, arg, alloc, IsNotTake, IsInitialization);
localAllocations.push_back(alloc);
forwardedArgs.push_back(alloc);
continue;
}
// To convert owned argument to be guaranteed, borrow the argument.
if (fromParam.isGuaranteed() && !toParam.isGuaranteed()) {
auto bbi = builder.emitBeginBorrowOperation(loc, arg);
forwardedArgs.push_back(bbi);
valuesToCleanup.push_back(bbi);
valuesToCleanup.push_back(arg);
continue;
}
// Otherwise, simply forward the argument.
forwardedArgs.push_back(arg);
}
}

SILFunction *getOrCreateReabstractionThunk(SILOptFunctionBuilder &fb,
SILModule &module, SILLocation loc,
SILFunction *caller,
Expand All @@ -274,18 +326,13 @@ SILFunction *getOrCreateReabstractionThunk(SILOptFunctionBuilder &fb,
thunkType, fromInterfaceType, toInterfaceType, Type(),
module.getSwiftModule());

// FIXME(TF-989): Mark reabstraction thunks as transparent. This requires
// generating ossa reabstraction thunks so that they can be inlined during
// mandatory inlining when `-enable-strip-ownership-after-serialization` is
// true and ownership model eliminator is not run after differentiation.
auto *thunk = fb.getOrCreateSharedFunction(
loc, name, thunkDeclType, IsBare, IsNotTransparent, IsSerialized,
loc, name, thunkDeclType, IsBare, IsTransparent, IsSerialized,
ProfileCounter(), IsReabstractionThunk, IsNotDynamic);
if (!thunk->empty())
return thunk;

thunk->setGenericEnvironment(genericEnv);
thunk->setOwnershipEliminated();
auto *entry = thunk->createBasicBlock();
SILBuilder builder(entry);
createEntryArguments(thunk);
Expand All @@ -294,13 +341,21 @@ SILFunction *getOrCreateReabstractionThunk(SILOptFunctionBuilder &fb,
SILFunctionConventions toConv(toType, module);
assert(toConv.useLoweredAddresses());

auto *fnArg = thunk->getArgumentsWithoutIndirectResults().back();
// Forward thunk arguments, handling ownership convention mismatches.
SmallVector<SILValue, 4> forwardedArgs;
for (auto indRes : thunk->getIndirectResults())
forwardedArgs.push_back(indRes);
SmallVector<AllocStackInst *, 4> localAllocations;
SmallVector<SILValue, 4> valuesToCleanup;
forwardFunctionArgumentsConvertingOwnership(
builder, loc, fromType, toType,
thunk->getArgumentsWithoutIndirectResults().drop_back(), forwardedArgs,
localAllocations, valuesToCleanup);

SmallVector<SILValue, 4> arguments;
auto toArgIter = thunk->getArguments().begin();
auto toArgIter = forwardedArgs.begin();
auto useNextArgument = [&]() { arguments.push_back(*toArgIter++); };

SmallVector<AllocStackInst *, 4> localAllocations;
auto createAllocStack = [&](SILType type) {
auto *alloc = builder.createAllocStack(loc, type);
localAllocations.push_back(alloc);
Expand Down Expand Up @@ -350,21 +405,25 @@ SILFunction *getOrCreateReabstractionThunk(SILOptFunctionBuilder &fb,
if (!paramTy.hasArchetype())
paramTy = thunk->mapTypeIntoContext(paramTy);
assert(paramTy.isAddress());
auto *toArg = *toArgIter++;
auto toArg = *toArgIter++;
auto *buf = createAllocStack(toArg->getType());
builder.createStore(loc, toArg, buf,
StoreOwnershipQualifier::Unqualified);
toArg = builder.emitCopyValueOperation(loc, toArg);
builder.emitStoreValueOperation(loc, toArg, buf,
StoreOwnershipQualifier::Init);
valuesToCleanup.push_back(buf);
arguments.push_back(buf);
continue;
}
// Convert direct parameter to indirect parameter.
assert(toParam.isFormalIndirect());
auto *toArg = *toArgIter++;
auto *load =
builder.createLoad(loc, toArg, LoadOwnershipQualifier::Unqualified);
auto toArg = *toArgIter++;
auto load = builder.emitLoadBorrowOperation(loc, toArg);
if (isa<LoadBorrowInst>(load))
valuesToCleanup.push_back(load);
arguments.push_back(load);
}

auto *fnArg = thunk->getArgumentsWithoutIndirectResults().back();
auto *apply = builder.createApply(loc, fnArg, SubstitutionMap(), arguments,
/*isNonThrowing*/ false);

Expand Down Expand Up @@ -413,8 +472,8 @@ SILFunction *getOrCreateReabstractionThunk(SILOptFunctionBuilder &fb,
// Load direct results from indirect results.
if (fromRes.isFormalIndirect()) {
auto indRes = *fromIndResultsIter++;
auto *load =
builder.createLoad(loc, indRes, LoadOwnershipQualifier::Unqualified);
auto load = builder.emitLoadValueOperation(loc, indRes,
LoadOwnershipQualifier::Take);
results.push_back(load);
continue;
}
Expand All @@ -426,11 +485,28 @@ SILFunction *getOrCreateReabstractionThunk(SILOptFunctionBuilder &fb,
assert(resultTy.isAddress());
#endif
auto indRes = *toIndResultsIter++;
builder.createStore(loc, *fromDirResultsIter++, indRes,
StoreOwnershipQualifier::Unqualified);
auto dirRes = *fromDirResultsIter++;
builder.emitStoreValueOperation(loc, dirRes, indRes,
StoreOwnershipQualifier::Init);
}
auto retVal = joinElements(results, builder, loc);

// Clean up local values.
// Guaranteed values need an `end_borrow`.
// Owned values need to be destroyed.
for (auto arg : valuesToCleanup) {
switch (arg.getOwnershipKind()) {
case ValueOwnershipKind::Guaranteed:
builder.emitEndBorrowOperation(loc, arg);
break;
case ValueOwnershipKind::Owned:
case ValueOwnershipKind::Unowned:
case ValueOwnershipKind::None:
builder.emitDestroyOperation(loc, arg);
break;
}
}

// Deallocate local allocations.
for (auto *alloc : llvm::reverse(localAllocations))
builder.createDeallocStack(loc, alloc);
Expand Down Expand Up @@ -549,11 +625,11 @@ getOrCreateSubsetParametersThunkForLinearMap(
auto *buf = builder.createAllocStack(loc, zeroSILObjType);
localAllocations.push_back(buf);
emitZeroIntoBuffer(builder, zeroType, buf, loc);
if (zeroSILType.isAddress())
if (zeroSILType.isAddress()) {
arguments.push_back(buf);
else {
auto *arg =
builder.createLoad(loc, buf, LoadOwnershipQualifier::Unqualified);
} else {
auto arg = builder.emitLoadValueOperation(loc, buf,
LoadOwnershipQualifier::Take);
arguments.push_back(arg);
}
break;
Expand Down Expand Up @@ -810,8 +886,6 @@ getOrCreateSubsetParametersThunkForDerivativeFunction(
if (!thunk->empty())
return {thunk, interfaceSubs};

// TODO(TF-1206): Enable ownership in all differentiation thunks.
thunk->setOwnershipEliminated();
thunk->setGenericEnvironment(genericEnv);
auto *entry = thunk->createBasicBlock();
SILBuilder builder(entry);
Expand Down