Skip to content
This repository was archived by the owner on Jan 10, 2023. It is now read-only.

Commit 49652f0

Browse files
authored
[AutoDiff] Generate transparent ossa reabstraction thunks. (swiftlang#33897)
Make the differentiation transform generate transparent, ossa reabstraction thunks. This enables these thunks to be inlined into other ossa functions (e.g. generated VJP and pullback functions) during mandatory inlining. Resolves TF-989. Unblocks further autodiff-related optimizations: SR-13390.
1 parent ce56277 commit 49652f0

File tree

1 file changed

+99
-25
lines changed

1 file changed

+99
-25
lines changed

lib/SILOptimizer/Differentiation/Thunk.cpp

Lines changed: 99 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,58 @@ CanSILFunctionType buildThunkType(SILFunction *fn,
249249
fn->getASTContext());
250250
}
251251

252+
/// Forward function arguments, handling ownership convention mismatches.
253+
/// Adapted from `forwardFunctionArguments` in SILGenPoly.cpp.
254+
///
255+
/// Forwarded arguments are appended to `forwardedArgs`.
256+
///
257+
/// Local allocations are appended to `localAllocations`. They need to be
258+
/// deallocated via `dealloc_stack`.
259+
///
260+
/// Local values requiring cleanup are appended to `valuesToCleanup`.
261+
static void forwardFunctionArgumentsConvertingOwnership(
262+
SILBuilder &builder, SILLocation loc, CanSILFunctionType fromTy,
263+
CanSILFunctionType toTy, ArrayRef<SILArgument *> originalArgs,
264+
SmallVectorImpl<SILValue> &forwardedArgs,
265+
SmallVectorImpl<AllocStackInst *> &localAllocations,
266+
SmallVectorImpl<SILValue> &valuesToCleanup) {
267+
auto fromParameters = fromTy->getParameters();
268+
auto toParameters = toTy->getParameters();
269+
assert(fromParameters.size() == toParameters.size());
270+
assert(fromParameters.size() == originalArgs.size());
271+
for (auto index : indices(originalArgs)) {
272+
auto &arg = originalArgs[index];
273+
auto fromParam = fromParameters[index];
274+
auto toParam = toParameters[index];
275+
// To convert guaranteed argument to be owned, create a copy.
276+
if (fromParam.isConsumed() && !toParam.isConsumed()) {
277+
// If the argument has an object type, create a `copy_value`.
278+
if (arg->getType().isObject()) {
279+
auto argCopy = builder.emitCopyValueOperation(loc, arg);
280+
forwardedArgs.push_back(argCopy);
281+
continue;
282+
}
283+
// If the argument has an address type, create a local allocation and
284+
// `copy_addr` its contents to the local allocation.
285+
auto *alloc = builder.createAllocStack(loc, arg->getType());
286+
builder.createCopyAddr(loc, arg, alloc, IsNotTake, IsInitialization);
287+
localAllocations.push_back(alloc);
288+
forwardedArgs.push_back(alloc);
289+
continue;
290+
}
291+
// To convert owned argument to be guaranteed, borrow the argument.
292+
if (fromParam.isGuaranteed() && !toParam.isGuaranteed()) {
293+
auto bbi = builder.emitBeginBorrowOperation(loc, arg);
294+
forwardedArgs.push_back(bbi);
295+
valuesToCleanup.push_back(bbi);
296+
valuesToCleanup.push_back(arg);
297+
continue;
298+
}
299+
// Otherwise, simply forward the argument.
300+
forwardedArgs.push_back(arg);
301+
}
302+
}
303+
252304
SILFunction *getOrCreateReabstractionThunk(SILOptFunctionBuilder &fb,
253305
SILModule &module, SILLocation loc,
254306
SILFunction *caller,
@@ -274,18 +326,13 @@ SILFunction *getOrCreateReabstractionThunk(SILOptFunctionBuilder &fb,
274326
thunkType, fromInterfaceType, toInterfaceType, Type(),
275327
module.getSwiftModule());
276328

277-
// FIXME(TF-989): Mark reabstraction thunks as transparent. This requires
278-
// generating ossa reabstraction thunks so that they can be inlined during
279-
// mandatory inlining when `-enable-strip-ownership-after-serialization` is
280-
// true and ownership model eliminator is not run after differentiation.
281329
auto *thunk = fb.getOrCreateSharedFunction(
282-
loc, name, thunkDeclType, IsBare, IsNotTransparent, IsSerialized,
330+
loc, name, thunkDeclType, IsBare, IsTransparent, IsSerialized,
283331
ProfileCounter(), IsReabstractionThunk, IsNotDynamic);
284332
if (!thunk->empty())
285333
return thunk;
286334

287335
thunk->setGenericEnvironment(genericEnv);
288-
thunk->setOwnershipEliminated();
289336
auto *entry = thunk->createBasicBlock();
290337
SILBuilder builder(entry);
291338
createEntryArguments(thunk);
@@ -294,13 +341,21 @@ SILFunction *getOrCreateReabstractionThunk(SILOptFunctionBuilder &fb,
294341
SILFunctionConventions toConv(toType, module);
295342
assert(toConv.useLoweredAddresses());
296343

297-
auto *fnArg = thunk->getArgumentsWithoutIndirectResults().back();
344+
// Forward thunk arguments, handling ownership convention mismatches.
345+
SmallVector<SILValue, 4> forwardedArgs;
346+
for (auto indRes : thunk->getIndirectResults())
347+
forwardedArgs.push_back(indRes);
348+
SmallVector<AllocStackInst *, 4> localAllocations;
349+
SmallVector<SILValue, 4> valuesToCleanup;
350+
forwardFunctionArgumentsConvertingOwnership(
351+
builder, loc, fromType, toType,
352+
thunk->getArgumentsWithoutIndirectResults().drop_back(), forwardedArgs,
353+
localAllocations, valuesToCleanup);
298354

299355
SmallVector<SILValue, 4> arguments;
300-
auto toArgIter = thunk->getArguments().begin();
356+
auto toArgIter = forwardedArgs.begin();
301357
auto useNextArgument = [&]() { arguments.push_back(*toArgIter++); };
302358

303-
SmallVector<AllocStackInst *, 4> localAllocations;
304359
auto createAllocStack = [&](SILType type) {
305360
auto *alloc = builder.createAllocStack(loc, type);
306361
localAllocations.push_back(alloc);
@@ -350,21 +405,25 @@ SILFunction *getOrCreateReabstractionThunk(SILOptFunctionBuilder &fb,
350405
if (!paramTy.hasArchetype())
351406
paramTy = thunk->mapTypeIntoContext(paramTy);
352407
assert(paramTy.isAddress());
353-
auto *toArg = *toArgIter++;
408+
auto toArg = *toArgIter++;
354409
auto *buf = createAllocStack(toArg->getType());
355-
builder.createStore(loc, toArg, buf,
356-
StoreOwnershipQualifier::Unqualified);
410+
toArg = builder.emitCopyValueOperation(loc, toArg);
411+
builder.emitStoreValueOperation(loc, toArg, buf,
412+
StoreOwnershipQualifier::Init);
413+
valuesToCleanup.push_back(buf);
357414
arguments.push_back(buf);
358415
continue;
359416
}
360417
// Convert direct parameter to indirect parameter.
361418
assert(toParam.isFormalIndirect());
362-
auto *toArg = *toArgIter++;
363-
auto *load =
364-
builder.createLoad(loc, toArg, LoadOwnershipQualifier::Unqualified);
419+
auto toArg = *toArgIter++;
420+
auto load = builder.emitLoadBorrowOperation(loc, toArg);
421+
if (isa<LoadBorrowInst>(load))
422+
valuesToCleanup.push_back(load);
365423
arguments.push_back(load);
366424
}
367425

426+
auto *fnArg = thunk->getArgumentsWithoutIndirectResults().back();
368427
auto *apply = builder.createApply(loc, fnArg, SubstitutionMap(), arguments,
369428
/*isNonThrowing*/ false);
370429

@@ -413,8 +472,8 @@ SILFunction *getOrCreateReabstractionThunk(SILOptFunctionBuilder &fb,
413472
// Load direct results from indirect results.
414473
if (fromRes.isFormalIndirect()) {
415474
auto indRes = *fromIndResultsIter++;
416-
auto *load =
417-
builder.createLoad(loc, indRes, LoadOwnershipQualifier::Unqualified);
475+
auto load = builder.emitLoadValueOperation(loc, indRes,
476+
LoadOwnershipQualifier::Take);
418477
results.push_back(load);
419478
continue;
420479
}
@@ -426,11 +485,28 @@ SILFunction *getOrCreateReabstractionThunk(SILOptFunctionBuilder &fb,
426485
assert(resultTy.isAddress());
427486
#endif
428487
auto indRes = *toIndResultsIter++;
429-
builder.createStore(loc, *fromDirResultsIter++, indRes,
430-
StoreOwnershipQualifier::Unqualified);
488+
auto dirRes = *fromDirResultsIter++;
489+
builder.emitStoreValueOperation(loc, dirRes, indRes,
490+
StoreOwnershipQualifier::Init);
431491
}
432492
auto retVal = joinElements(results, builder, loc);
433493

494+
// Clean up local values.
495+
// Guaranteed values need an `end_borrow`.
496+
// Owned values need to be destroyed.
497+
for (auto arg : valuesToCleanup) {
498+
switch (arg.getOwnershipKind()) {
499+
case ValueOwnershipKind::Guaranteed:
500+
builder.emitEndBorrowOperation(loc, arg);
501+
break;
502+
case ValueOwnershipKind::Owned:
503+
case ValueOwnershipKind::Unowned:
504+
case ValueOwnershipKind::None:
505+
builder.emitDestroyOperation(loc, arg);
506+
break;
507+
}
508+
}
509+
434510
// Deallocate local allocations.
435511
for (auto *alloc : llvm::reverse(localAllocations))
436512
builder.createDeallocStack(loc, alloc);
@@ -549,11 +625,11 @@ getOrCreateSubsetParametersThunkForLinearMap(
549625
auto *buf = builder.createAllocStack(loc, zeroSILObjType);
550626
localAllocations.push_back(buf);
551627
emitZeroIntoBuffer(builder, zeroType, buf, loc);
552-
if (zeroSILType.isAddress())
628+
if (zeroSILType.isAddress()) {
553629
arguments.push_back(buf);
554-
else {
555-
auto *arg =
556-
builder.createLoad(loc, buf, LoadOwnershipQualifier::Unqualified);
630+
} else {
631+
auto arg = builder.emitLoadValueOperation(loc, buf,
632+
LoadOwnershipQualifier::Take);
557633
arguments.push_back(arg);
558634
}
559635
break;
@@ -810,8 +886,6 @@ getOrCreateSubsetParametersThunkForDerivativeFunction(
810886
if (!thunk->empty())
811887
return {thunk, interfaceSubs};
812888

813-
// TODO(TF-1206): Enable ownership in all differentiation thunks.
814-
thunk->setOwnershipEliminated();
815889
thunk->setGenericEnvironment(genericEnv);
816890
auto *entry = thunk->createBasicBlock();
817891
SILBuilder builder(entry);

0 commit comments

Comments
 (0)