Skip to content

Commit d13348c

Browse files
Merge pull request #62494 from nate-chandler/opaque-values/2/20221209
[OpaqueValues] Initial support for AD.
2 parents b8b1e34 + c8bce4a commit d13348c

File tree

2 files changed

+86
-3
lines changed

2 files changed

+86
-3
lines changed

lib/SILGen/SILGenBuiltin.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1200,7 +1200,8 @@ static ManagedValue emitBuiltinAutoDiffApplyDerivativeFunction(
12001200
assert(derivativeFnType->isTrivialNoEscape());
12011201

12021202
// Do the apply for the indirect result case.
1203-
if (derivativeFnType->hasIndirectFormalResults()) {
1203+
if (derivativeFnType->hasIndirectFormalResults() &&
1204+
SGF.SGM.M.useLoweredAddresses()) {
12041205
auto indResBuffer = SGF.getBufferForExprResult(
12051206
loc, derivativeFnType->getAllResultsInterfaceType(), C);
12061207
SmallVector<SILValue, 3> applyArgs;

lib/SILGen/SILGenPoly.cpp

Lines changed: 84 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4970,7 +4970,86 @@ ManagedValue SILGenFunction::getThunkedAutoDiffLinearMap(
49704970

49714971
SILFunctionConventions fromConv(fromType, getModule());
49724972
SILFunctionConventions toConv(toType, getModule());
4973-
assert(toConv.useLoweredAddresses());
4973+
if (!toConv.useLoweredAddresses()) {
4974+
SmallVector<ManagedValue, 4> thunkArguments;
4975+
for (auto *indRes : thunkIndirectResults)
4976+
thunkArguments.push_back(ManagedValue::forLValue(indRes));
4977+
thunkArguments.append(params.begin(), params.end());
4978+
SmallVector<SILParameterInfo, 4> toParameters(
4979+
toConv.getParameters().begin(), toConv.getParameters().end());
4980+
SmallVector<SILResultInfo, 4> toResults(toConv.getResults().begin(),
4981+
toConv.getResults().end());
4982+
// Handle self reordering.
4983+
// - For pullbacks: reorder result infos.
4984+
// - For differentials: reorder parameter infos and arguments.
4985+
auto numIndirectResults = thunkIndirectResults.size();
4986+
if (reorderSelf && linearMapKind == AutoDiffLinearMapKind::Pullback &&
4987+
toResults.size() > 1) {
4988+
std::rotate(toResults.begin(), toResults.end() - 1, toResults.end());
4989+
}
4990+
if (reorderSelf && linearMapKind == AutoDiffLinearMapKind::Differential &&
4991+
thunkArguments.size() > 1) {
4992+
// Before: [arg1, arg2, ..., arg_self, df]
4993+
// After: [arg_self, arg1, arg2, ..., df]
4994+
std::rotate(thunkArguments.begin() + numIndirectResults,
4995+
thunkArguments.end() - 2, thunkArguments.end() - 1);
4996+
// Before: [arg1, arg2, ..., arg_self]
4997+
// After: [arg_self, arg1, arg2, ...]
4998+
std::rotate(toParameters.begin(), toParameters.end() - 1,
4999+
toParameters.end());
5000+
}
5001+
5002+
// Correctness assertions.
5003+
#ifndef NDEBUG
5004+
assert(toType->getNumParameters() == fromType->getNumParameters());
5005+
for (unsigned paramIdx : range(toType->getNumParameters())) {
5006+
auto fromParam = fromConv.getParameters()[paramIdx];
5007+
auto toParam = toParameters[paramIdx];
5008+
assert(fromParam.getInterfaceType() == toParam.getInterfaceType());
5009+
}
5010+
assert(fromType->getNumResults() == toType->getNumResults());
5011+
for (unsigned resIdx : range(toType->getNumResults())) {
5012+
auto fromRes = fromConv.getResults()[resIdx];
5013+
auto toRes = toResults[resIdx];
5014+
assert(fromRes.getInterfaceType() == toRes.getInterfaceType());
5015+
}
5016+
#endif // NDEBUG
5017+
5018+
auto *linearMapArg = thunk->getArguments().back();
5019+
SmallVector<SILValue, 4> arguments;
5020+
for (unsigned paramIdx : range(toType->getNumParameters())) {
5021+
arguments.push_back(thunkArguments[paramIdx].getValue());
5022+
}
5023+
auto *apply =
5024+
thunkSGF.B.createApply(loc, linearMapArg, SubstitutionMap(), arguments);
5025+
5026+
// Get return elements.
5027+
SmallVector<SILValue, 4> results;
5028+
extractAllElements(apply, loc, thunkSGF.B, results);
5029+
5030+
// Handle self reordering.
5031+
// For pullbacks: rotate direct results if self is direct.
5032+
if (reorderSelf && linearMapKind == AutoDiffLinearMapKind::Pullback) {
5033+
auto fromSelfResult = fromConv.getResults().front();
5034+
auto toSelfResult = toConv.getResults().back();
5035+
assert(fromSelfResult.getInterfaceType() ==
5036+
toSelfResult.getInterfaceType());
5037+
// Before: [dir_res_self, dir_res1, dir_res2, ...]
5038+
// After: [dir_res1, dir_res2, ..., dir_res_self]
5039+
if (results.size() > 1) {
5040+
std::rotate(results.begin(), results.begin() + 1, results.end());
5041+
}
5042+
}
5043+
auto retVal = joinElements(results, thunkSGF.B, loc);
5044+
5045+
// Emit cleanups.
5046+
thunkSGF.Cleanups.emitCleanupsForReturn(CleanupLocation(loc), NotForUnwind);
5047+
5048+
// Create return.
5049+
thunkSGF.B.createReturn(loc, retVal);
5050+
5051+
return getThunkedResult();
5052+
}
49745053

49755054
SmallVector<ManagedValue, 4> thunkArguments;
49765055
for (auto *indRes : thunkIndirectResults)
@@ -5308,7 +5387,10 @@ SILFunction *SILGenModule::getOrCreateCustomDerivativeThunk(
53085387
};
53095388

53105389
if (!reorderSelf && linearMapFnType == targetLinearMapFnType) {
5311-
createReturn(apply);
5390+
SmallVector<SILValue, 8> results;
5391+
extractAllElements(apply, loc, thunkSGF.B, results);
5392+
auto result = joinElements(results, thunkSGF.B, apply.getLoc());
5393+
createReturn(result);
53125394
return thunk;
53135395
}
53145396

0 commit comments

Comments
 (0)