@@ -4970,7 +4970,86 @@ ManagedValue SILGenFunction::getThunkedAutoDiffLinearMap(
4970
4970
4971
4971
SILFunctionConventions fromConv (fromType, getModule ());
4972
4972
SILFunctionConventions toConv (toType, getModule ());
4973
- assert (toConv.useLoweredAddresses ());
4973
+ if (!toConv.useLoweredAddresses ()) {
4974
+ SmallVector<ManagedValue, 4 > thunkArguments;
4975
+ for (auto *indRes : thunkIndirectResults)
4976
+ thunkArguments.push_back (ManagedValue::forLValue (indRes));
4977
+ thunkArguments.append (params.begin (), params.end ());
4978
+ SmallVector<SILParameterInfo, 4 > toParameters (
4979
+ toConv.getParameters ().begin (), toConv.getParameters ().end ());
4980
+ SmallVector<SILResultInfo, 4 > toResults (toConv.getResults ().begin (),
4981
+ toConv.getResults ().end ());
4982
+ // Handle self reordering.
4983
+ // - For pullbacks: reorder result infos.
4984
+ // - For differentials: reorder parameter infos and arguments.
4985
+ auto numIndirectResults = thunkIndirectResults.size ();
4986
+ if (reorderSelf && linearMapKind == AutoDiffLinearMapKind::Pullback &&
4987
+ toResults.size () > 1 ) {
4988
+ std::rotate (toResults.begin (), toResults.end () - 1 , toResults.end ());
4989
+ }
4990
+ if (reorderSelf && linearMapKind == AutoDiffLinearMapKind::Differential &&
4991
+ thunkArguments.size () > 1 ) {
4992
+ // Before: [arg1, arg2, ..., arg_self, df]
4993
+ // After: [arg_self, arg1, arg2, ..., df]
4994
+ std::rotate (thunkArguments.begin () + numIndirectResults,
4995
+ thunkArguments.end () - 2 , thunkArguments.end () - 1 );
4996
+ // Before: [arg1, arg2, ..., arg_self]
4997
+ // After: [arg_self, arg1, arg2, ...]
4998
+ std::rotate (toParameters.begin (), toParameters.end () - 1 ,
4999
+ toParameters.end ());
5000
+ }
5001
+
5002
+ // Correctness assertions.
5003
+ #ifndef NDEBUG
5004
+ assert (toType->getNumParameters () == fromType->getNumParameters ());
5005
+ for (unsigned paramIdx : range (toType->getNumParameters ())) {
5006
+ auto fromParam = fromConv.getParameters ()[paramIdx];
5007
+ auto toParam = toParameters[paramIdx];
5008
+ assert (fromParam.getInterfaceType () == toParam.getInterfaceType ());
5009
+ }
5010
+ assert (fromType->getNumResults () == toType->getNumResults ());
5011
+ for (unsigned resIdx : range (toType->getNumResults ())) {
5012
+ auto fromRes = fromConv.getResults ()[resIdx];
5013
+ auto toRes = toResults[resIdx];
5014
+ assert (fromRes.getInterfaceType () == toRes.getInterfaceType ());
5015
+ }
5016
+ #endif // NDEBUG
5017
+
5018
+ auto *linearMapArg = thunk->getArguments ().back ();
5019
+ SmallVector<SILValue, 4 > arguments;
5020
+ for (unsigned paramIdx : range (toType->getNumParameters ())) {
5021
+ arguments.push_back (thunkArguments[paramIdx].getValue ());
5022
+ }
5023
+ auto *apply =
5024
+ thunkSGF.B .createApply (loc, linearMapArg, SubstitutionMap (), arguments);
5025
+
5026
+ // Get return elements.
5027
+ SmallVector<SILValue, 4 > results;
5028
+ extractAllElements (apply, loc, thunkSGF.B , results);
5029
+
5030
+ // Handle self reordering.
5031
+ // For pullbacks: rotate direct results if self is direct.
5032
+ if (reorderSelf && linearMapKind == AutoDiffLinearMapKind::Pullback) {
5033
+ auto fromSelfResult = fromConv.getResults ().front ();
5034
+ auto toSelfResult = toConv.getResults ().back ();
5035
+ assert (fromSelfResult.getInterfaceType () ==
5036
+ toSelfResult.getInterfaceType ());
5037
+ // Before: [dir_res_self, dir_res1, dir_res2, ...]
5038
+ // After: [dir_res1, dir_res2, ..., dir_res_self]
5039
+ if (results.size () > 1 ) {
5040
+ std::rotate (results.begin (), results.begin () + 1 , results.end ());
5041
+ }
5042
+ }
5043
+ auto retVal = joinElements (results, thunkSGF.B , loc);
5044
+
5045
+ // Emit cleanups.
5046
+ thunkSGF.Cleanups .emitCleanupsForReturn (CleanupLocation (loc), NotForUnwind);
5047
+
5048
+ // Create return.
5049
+ thunkSGF.B .createReturn (loc, retVal);
5050
+
5051
+ return getThunkedResult ();
5052
+ }
4974
5053
4975
5054
SmallVector<ManagedValue, 4 > thunkArguments;
4976
5055
for (auto *indRes : thunkIndirectResults)
@@ -5308,7 +5387,10 @@ SILFunction *SILGenModule::getOrCreateCustomDerivativeThunk(
5308
5387
};
5309
5388
5310
5389
if (!reorderSelf && linearMapFnType == targetLinearMapFnType) {
5311
- createReturn (apply);
5390
+ SmallVector<SILValue, 8 > results;
5391
+ extractAllElements (apply, loc, thunkSGF.B , results);
5392
+ auto result = joinElements (results, thunkSGF.B , apply.getLoc ());
5393
+ createReturn (result);
5312
5394
return thunk;
5313
5395
}
5314
5396
0 commit comments