Skip to content

Commit 9c95e27

Browse files
authored
[AutoDiff] Enable ownership for AD linear map SILGen thunks. (#26503)
Enable ownership for AD linear map thunk SILGen thunks. Ownership verification ensures that such thunks do not leak memory. Todos: - Much of this code is ad-hoc and manually written. Investigate how to generalize SIL reabstraction infrastructure. - Enable ownership for all AD SILGen thunks (namely associated function thunks).
1 parent 2f89554 commit 9c95e27

File tree

2 files changed

+53
-42
lines changed

2 files changed

+53
-42
lines changed

lib/SILGen/SILGenPoly.cpp

Lines changed: 50 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -3360,13 +3360,23 @@ static CanSILFunctionType buildWithoutActuallyEscapingThunkType(
33603360
// SWIFT_ENABLE_TENSORFLOW
33613361
/// Given a value, extracts all elements to `result` from this value if it's a
33623362
/// tuple. Otherwise, add this value directly to `result`.
3363-
static void extractAllElements(SILValue val, SILBuilder &builder,
3363+
static void extractAllElements(SILValue val, SILLocation loc,
3364+
SILBuilder &builder,
33643365
SmallVectorImpl<SILValue> &result) {
3365-
if (auto tupleType = val->getType().getAs<TupleType>())
3366-
for (auto i : range(tupleType->getNumElements()))
3367-
result.push_back(builder.createTupleExtract(val.getLoc(), val, i));
3368-
else
3366+
auto &fn = builder.getFunction();
3367+
auto tupleType = val->getType().getAs<TupleType>();
3368+
if (!tupleType) {
33693369
result.push_back(val);
3370+
return;
3371+
}
3372+
if (!fn.hasOwnership()) {
3373+
for (auto i : range(tupleType->getNumElements()))
3374+
result.push_back(builder.createTupleExtract(loc, val, i));
3375+
return;
3376+
}
3377+
if (tupleType->getNumElements() == 0)
3378+
return;
3379+
builder.emitDestructureValueOperation(loc, val, result);
33703380
}
33713381

33723382
// SWIFT_ENABLE_TENSORFLOW
@@ -3408,17 +3418,18 @@ SILGenFunction::getThunkedAutoDiffLinearMap(
34083418
thunkType, fromInterfaceType, toInterfaceType,
34093419
Type(), getModule().getSwiftModule());
34103420
// TODO(TF-685): Use principled thunk mangling.
3411-
if (reorderSelf) {
3412-
switch (assocFnKind) {
3413-
case AutoDiffAssociatedFunctionKind::JVP:
3414-
name += "_differential";
3415-
break;
3416-
case AutoDiffAssociatedFunctionKind::VJP:
3417-
name += "_pullback";
3418-
break;
3419-
}
3420-
name = "AD__" + name + "_self_reordering_thunk";
3421+
switch (assocFnKind) {
3422+
case AutoDiffAssociatedFunctionKind::JVP:
3423+
name += "_differential";
3424+
break;
3425+
case AutoDiffAssociatedFunctionKind::VJP:
3426+
name += "_pullback";
3427+
break;
34213428
}
3429+
name = "AD__" + name;
3430+
if (reorderSelf)
3431+
name += "_self_reordering";
3432+
name += "_thunk";
34223433

34233434
// Create the thunk.
34243435
auto loc = F.getLocation();
@@ -3441,7 +3452,6 @@ SILGenFunction::getThunkedAutoDiffLinearMap(
34413452
if (!thunk->empty())
34423453
return getThunkedResult();
34433454
thunk->setGenericEnvironment(genericEnv);
3444-
thunk->setOwnershipEliminated();
34453455

34463456
SILGenFunction thunkSGF(SGM, *thunk, FunctionDC);
34473457
SmallVector<ManagedValue, 4> params;
@@ -3452,8 +3462,10 @@ SILGenFunction::getThunkedAutoDiffLinearMap(
34523462
SILFunctionConventions toConv(toType, getModule());
34533463
assert(toConv.useLoweredAddresses());
34543464

3455-
SmallVector<SILArgument *, 4> thunkArguments(thunk->getArguments().begin(),
3456-
thunk->getArguments().end());
3465+
SmallVector<ManagedValue, 4> thunkArguments;
3466+
for (auto *indRes : thunkIndirectResults)
3467+
thunkArguments.push_back(ManagedValue::forLValue(indRes));
3468+
thunkArguments.append(params.begin(), params.end());
34573469
SmallVector<SILParameterInfo, 4> toParameters(toConv.getParameters().begin(),
34583470
toConv.getParameters().end());
34593471
SmallVector<SILResultInfo, 4> toResults(toConv.getResults().begin(),
@@ -3472,17 +3484,13 @@ SILGenFunction::getThunkedAutoDiffLinearMap(
34723484
thunkArguments.begin() + numIndirectResults - 1,
34733485
thunkArguments.begin() + numIndirectResults);
34743486
}
3475-
std::rotate(toResults.begin(),
3476-
toResults.end() - 1,
3477-
toResults.end());
3487+
std::rotate(toResults.begin(), toResults.end() - 1, toResults.end());
34783488
}
34793489
if (reorderSelf && assocFnKind == AutoDiffAssociatedFunctionKind::JVP &&
34803490
thunkArguments.size() > 1) {
34813491
std::rotate(thunkArguments.begin() + numIndirectResults,
3482-
thunkArguments.end() - 2,
3483-
thunkArguments.end() - 1);
3484-
std::rotate(toParameters.begin(),
3485-
toParameters.end() - 1,
3492+
thunkArguments.end() - 2, thunkArguments.end() - 1);
3493+
std::rotate(toParameters.begin(), toParameters.end() - 1,
34863494
toParameters.end());
34873495
}
34883496

@@ -3506,7 +3514,8 @@ SILGenFunction::getThunkedAutoDiffLinearMap(
35063514
SmallVector<SILValue, 4> arguments;
35073515
auto toArgIter = thunkArguments.begin();
35083516
auto useNextArgument = [&]() {
3509-
arguments.push_back(*toArgIter++);
3517+
auto nextArgument = *toArgIter++;
3518+
arguments.push_back(nextArgument.getValue());
35103519
};
35113520

35123521
SmallVector<AllocStackInst *, 4> localAllocations;
@@ -3555,19 +3564,17 @@ SILGenFunction::getThunkedAutoDiffLinearMap(
35553564
if (!paramTy.hasArchetype())
35563565
paramTy = thunk->mapTypeIntoContext(paramTy);
35573566
assert(paramTy.isAddress());
3558-
auto *toArg = *toArgIter++;
3567+
auto toArg = (*toArgIter++).getValue();
35593568
auto *buf = createAllocStack(toArg->getType());
3560-
thunkSGF.B.createStore(
3561-
loc, toArg, buf, StoreOwnershipQualifier::Unqualified);
3569+
thunkSGF.B.createStore(loc, toArg, buf, StoreOwnershipQualifier::Init);
35623570
arguments.push_back(buf);
35633571
continue;
35643572
}
35653573
// Convert direct parameter to indirect parameter.
35663574
assert(toParam.isFormalIndirect());
3567-
auto *toArg = *toArgIter++;
3568-
auto *load =
3569-
thunkSGF.B.createLoad(loc, toArg, LoadOwnershipQualifier::Unqualified);
3570-
arguments.push_back(load);
3575+
auto toArg = (*toArgIter++).getValue();
3576+
auto load = thunkSGF.emitManagedLoadBorrow(loc, toArg);
3577+
arguments.push_back(load.getValue());
35713578
}
35723579

35733580
auto *linearMapArg = thunk->getArgumentsWithoutIndirectResults().back();
@@ -3578,7 +3585,7 @@ SILGenFunction::getThunkedAutoDiffLinearMap(
35783585
SmallVector<SILValue, 4> results;
35793586
// Extract all direct results.
35803587
SmallVector<SILValue, 4> directResults;
3581-
extractAllElements(apply, thunkSGF.B, directResults);
3588+
extractAllElements(apply, loc, thunkSGF.B, directResults);
35823589

35833590
// Handle self reordering.
35843591
// For pullbacks: rotate direct results if self is direct.
@@ -3629,11 +3636,16 @@ SILGenFunction::getThunkedAutoDiffLinearMap(
36293636
SILType resultTy = toConv.getSILType(toRes);
36303637
assert(resultTy.isAddress());
36313638
auto indRes = *toIndResultsIter++;
3632-
thunkSGF.B.createStore(loc, *fromDirResultsIter++, indRes,
3633-
StoreOwnershipQualifier::Unqualified);
3639+
thunkSGF.emitSemanticStore(loc, *fromDirResultsIter++, indRes,
3640+
thunkSGF.getTypeLowering(resultTy),
3641+
IsInitialization);
36343642
}
36353643
auto retVal = joinElements(results, thunkSGF.B, loc);
36363644

3645+
// Emit cleanups.
3646+
thunkSGF.Cleanups.emitCleanupsForReturn(
3647+
CleanupLocation::get(loc), NotForUnwind);
3648+
36373649
// Deallocate local allocations.
36383650
for (auto *alloc : reversed(localAllocations))
36393651
thunkSGF.B.createDeallocStack(loc, alloc);
@@ -3758,8 +3770,8 @@ SILGenModule::getOrCreateAutoDiffAssociatedFunctionThunk(
37583770
thunk->getForwardingSubstitutionMap(), arguments);
37593771

37603772
SmallVector<SILValue, 8> directResults;
3761-
extractAllElements(apply, thunkSGF.B, directResults);
3762-
auto linearMap = ManagedValue::forBorrowedObjectRValue(directResults.back());
3773+
extractAllElements(apply, loc, thunkSGF.B, directResults);
3774+
auto linearMap = ManagedValue::forUnmanaged(directResults.back());
37633775
auto linearMapFnType = linearMap.getType().castTo<SILFunctionType>();
37643776
auto targetLinearMapFnType = thunk->mapTypeIntoContext(
37653777
origAssocFnType->getResults().back().getSILStorageType())

test/AutoDiff/silgen_thunking/main.swift

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,10 @@ struct TF_698 : Differentiable & AdditiveArithmetic {
5959
// CHECK: [[RESULT:%.*]] = tuple ([[VJP_ORIG_RESULT]] : $TF_698, [[THUNKED_PB]] : {{.*}})
6060
// CHECK: return [[RESULT]]
6161

62-
// CHECK-LABEL: sil shared [transparent] [serialized] [reabstraction_thunk] @AD__$s4main6TF_698VA2CIeggoo_A3CIeggoo_TR_pullback_self_reordering_thunk : $@convention(thin) (@guaranteed TF_698, @guaranteed @callee_guaranteed (@guaranteed TF_698) -> (@owned TF_698, @owned TF_698))
63-
// CHECK: bb0([[SEED:%.*]] : $TF_698, [[PB:%.*]] : $@callee_guaranteed (@guaranteed TF_698) -> (@owned TF_698, @owned TF_698)):
62+
// CHECK-LABEL: sil shared [transparent] [serialized] [reabstraction_thunk] [ossa] @AD__$s4main6TF_698VA2CIeggoo_A3CIeggoo_TR_pullback_self_reordering_thunk : $@convention(thin) (@guaranteed TF_698, @guaranteed @callee_guaranteed (@guaranteed TF_698) -> (@owned TF_698, @owned TF_698))
63+
// CHECK: bb0([[SEED:%.*]] : @guaranteed $TF_698, [[PB:%.*]] : @guaranteed $@callee_guaranteed (@guaranteed TF_698) -> (@owned TF_698, @owned TF_698)):
6464
// CHECK: [[PB_RESULT:%.*]] = apply [[PB]]([[SEED]])
65-
// CHECK: [[X_ADJ:%.*]] = tuple_extract [[PB_RESULT]] : $(TF_698, TF_698), 0
66-
// CHECK: [[Y_ADJ:%.*]] = tuple_extract [[PB_RESULT]] : $(TF_698, TF_698), 1
65+
// CHECK: ([[X_ADJ:%.*]], [[Y_ADJ:%.*]]) = destructure_tuple %2 : $(TF_698, TF_698)
6766
// CHECK: [[RESULT:%.*]] = tuple ([[Y_ADJ]] : $TF_698, [[X_ADJ]] : $TF_698)
6867
// CHECK: return [[RESULT]]
6968

0 commit comments

Comments
 (0)