Skip to content

Commit ce4c7f5

Browse files
committed
[OpaqueValues] Initial support for AD.
Just enough to build _Differentiation.
1 parent 0b05a1e commit ce4c7f5

File tree

2 files changed

+86
-3
lines changed

2 files changed

+86
-3
lines changed

lib/SILGen/SILGenBuiltin.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1201,7 +1201,8 @@ static ManagedValue emitBuiltinAutoDiffApplyDerivativeFunction(
12011201
assert(derivativeFn->getType().isTrivial(SGF.F));
12021202

12031203
// Do the apply for the indirect result case.
1204-
if (derivativeFnType->hasIndirectFormalResults()) {
1204+
if (derivativeFnType->hasIndirectFormalResults() &&
1205+
SGF.SGM.M.useLoweredAddresses()) {
12051206
auto indResBuffer = SGF.getBufferForExprResult(
12061207
loc, derivativeFnType->getAllResultsInterfaceType(), C);
12071208
SmallVector<SILValue, 3> applyArgs;

lib/SILGen/SILGenPoly.cpp

Lines changed: 84 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3496,7 +3496,86 @@ ManagedValue SILGenFunction::getThunkedAutoDiffLinearMap(
34963496

34973497
SILFunctionConventions fromConv(fromType, getModule());
34983498
SILFunctionConventions toConv(toType, getModule());
3499-
assert(toConv.useLoweredAddresses());
3499+
if (!toConv.useLoweredAddresses()) {
3500+
SmallVector<ManagedValue, 4> thunkArguments;
3501+
for (auto *indRes : thunkIndirectResults)
3502+
thunkArguments.push_back(ManagedValue::forLValue(indRes));
3503+
thunkArguments.append(params.begin(), params.end());
3504+
SmallVector<SILParameterInfo, 4> toParameters(
3505+
toConv.getParameters().begin(), toConv.getParameters().end());
3506+
SmallVector<SILResultInfo, 4> toResults(toConv.getResults().begin(),
3507+
toConv.getResults().end());
3508+
// Handle self reordering.
3509+
// - For pullbacks: reorder result infos.
3510+
// - For differentials: reorder parameter infos and arguments.
3511+
auto numIndirectResults = thunkIndirectResults.size();
3512+
if (reorderSelf && linearMapKind == AutoDiffLinearMapKind::Pullback &&
3513+
toResults.size() > 1) {
3514+
std::rotate(toResults.begin(), toResults.end() - 1, toResults.end());
3515+
}
3516+
if (reorderSelf && linearMapKind == AutoDiffLinearMapKind::Differential &&
3517+
thunkArguments.size() > 1) {
3518+
// Before: [arg1, arg2, ..., arg_self, df]
3519+
// After: [arg_self, arg1, arg2, ..., df]
3520+
std::rotate(thunkArguments.begin() + numIndirectResults,
3521+
thunkArguments.end() - 2, thunkArguments.end() - 1);
3522+
// Before: [arg1, arg2, ..., arg_self]
3523+
// After: [arg_self, arg1, arg2, ...]
3524+
std::rotate(toParameters.begin(), toParameters.end() - 1,
3525+
toParameters.end());
3526+
}
3527+
3528+
// Correctness assertions.
3529+
#ifndef NDEBUG
3530+
assert(toType->getNumParameters() == fromType->getNumParameters());
3531+
for (unsigned paramIdx : range(toType->getNumParameters())) {
3532+
auto fromParam = fromConv.getParameters()[paramIdx];
3533+
auto toParam = toParameters[paramIdx];
3534+
assert(fromParam.getInterfaceType() == toParam.getInterfaceType());
3535+
}
3536+
assert(fromType->getNumResults() == toType->getNumResults());
3537+
for (unsigned resIdx : range(toType->getNumResults())) {
3538+
auto fromRes = fromConv.getResults()[resIdx];
3539+
auto toRes = toResults[resIdx];
3540+
assert(fromRes.getInterfaceType() == toRes.getInterfaceType());
3541+
}
3542+
#endif // NDEBUG
3543+
3544+
auto *linearMapArg = thunk->getArguments().back();
3545+
SmallVector<SILValue, 4> arguments;
3546+
for (unsigned paramIdx : range(toType->getNumParameters())) {
3547+
arguments.push_back(thunkArguments[paramIdx].getValue());
3548+
}
3549+
auto *apply =
3550+
thunkSGF.B.createApply(loc, linearMapArg, SubstitutionMap(), arguments);
3551+
3552+
// Get return elements.
3553+
SmallVector<SILValue, 4> results;
3554+
extractAllElements(apply, loc, thunkSGF.B, results);
3555+
3556+
// Handle self reordering.
3557+
// For pullbacks: rotate direct results if self is direct.
3558+
if (reorderSelf && linearMapKind == AutoDiffLinearMapKind::Pullback) {
3559+
auto fromSelfResult = fromConv.getResults().front();
3560+
auto toSelfResult = toConv.getResults().back();
3561+
assert(fromSelfResult.getInterfaceType() ==
3562+
toSelfResult.getInterfaceType());
3563+
// Before: [dir_res_self, dir_res1, dir_res2, ...]
3564+
// After: [dir_res1, dir_res2, ..., dir_res_self]
3565+
if (results.size() > 1) {
3566+
std::rotate(results.begin(), results.begin() + 1, results.end());
3567+
}
3568+
}
3569+
auto retVal = joinElements(results, thunkSGF.B, loc);
3570+
3571+
// Emit cleanups.
3572+
thunkSGF.Cleanups.emitCleanupsForReturn(CleanupLocation(loc), NotForUnwind);
3573+
3574+
// Create return.
3575+
thunkSGF.B.createReturn(loc, retVal);
3576+
3577+
return getThunkedResult();
3578+
}
35003579

35013580
SmallVector<ManagedValue, 4> thunkArguments;
35023581
for (auto *indRes : thunkIndirectResults)
@@ -3833,7 +3912,10 @@ SILFunction *SILGenModule::getOrCreateCustomDerivativeThunk(
38333912
};
38343913

38353914
if (!reorderSelf && linearMapFnType == targetLinearMapFnType) {
3836-
createReturn(apply);
3915+
SmallVector<SILValue, 8> results;
3916+
extractAllElements(apply, loc, thunkSGF.B, results);
3917+
auto result = joinElements(results, thunkSGF.B, apply.getLoc());
3918+
createReturn(result);
38373919
return thunk;
38383920
}
38393921

0 commit comments

Comments
 (0)