@@ -3496,7 +3496,86 @@ ManagedValue SILGenFunction::getThunkedAutoDiffLinearMap(
3496
3496
3497
3497
SILFunctionConventions fromConv (fromType, getModule ());
3498
3498
SILFunctionConventions toConv (toType, getModule ());
3499
- assert (toConv.useLoweredAddresses ());
3499
+ if (!toConv.useLoweredAddresses ()) {
3500
+ SmallVector<ManagedValue, 4 > thunkArguments;
3501
+ for (auto *indRes : thunkIndirectResults)
3502
+ thunkArguments.push_back (ManagedValue::forLValue (indRes));
3503
+ thunkArguments.append (params.begin (), params.end ());
3504
+ SmallVector<SILParameterInfo, 4 > toParameters (
3505
+ toConv.getParameters ().begin (), toConv.getParameters ().end ());
3506
+ SmallVector<SILResultInfo, 4 > toResults (toConv.getResults ().begin (),
3507
+ toConv.getResults ().end ());
3508
+ // Handle self reordering.
3509
+ // - For pullbacks: reorder result infos.
3510
+ // - For differentials: reorder parameter infos and arguments.
3511
+ auto numIndirectResults = thunkIndirectResults.size ();
3512
+ if (reorderSelf && linearMapKind == AutoDiffLinearMapKind::Pullback &&
3513
+ toResults.size () > 1 ) {
3514
+ std::rotate (toResults.begin (), toResults.end () - 1 , toResults.end ());
3515
+ }
3516
+ if (reorderSelf && linearMapKind == AutoDiffLinearMapKind::Differential &&
3517
+ thunkArguments.size () > 1 ) {
3518
+ // Before: [arg1, arg2, ..., arg_self, df]
3519
+ // After: [arg_self, arg1, arg2, ..., df]
3520
+ std::rotate (thunkArguments.begin () + numIndirectResults,
3521
+ thunkArguments.end () - 2 , thunkArguments.end () - 1 );
3522
+ // Before: [arg1, arg2, ..., arg_self]
3523
+ // After: [arg_self, arg1, arg2, ...]
3524
+ std::rotate (toParameters.begin (), toParameters.end () - 1 ,
3525
+ toParameters.end ());
3526
+ }
3527
+
3528
+ // Correctness assertions.
3529
+ #ifndef NDEBUG
3530
+ assert (toType->getNumParameters () == fromType->getNumParameters ());
3531
+ for (unsigned paramIdx : range (toType->getNumParameters ())) {
3532
+ auto fromParam = fromConv.getParameters ()[paramIdx];
3533
+ auto toParam = toParameters[paramIdx];
3534
+ assert (fromParam.getInterfaceType () == toParam.getInterfaceType ());
3535
+ }
3536
+ assert (fromType->getNumResults () == toType->getNumResults ());
3537
+ for (unsigned resIdx : range (toType->getNumResults ())) {
3538
+ auto fromRes = fromConv.getResults ()[resIdx];
3539
+ auto toRes = toResults[resIdx];
3540
+ assert (fromRes.getInterfaceType () == toRes.getInterfaceType ());
3541
+ }
3542
+ #endif // NDEBUG
3543
+
3544
+ auto *linearMapArg = thunk->getArguments ().back ();
3545
+ SmallVector<SILValue, 4 > arguments;
3546
+ for (unsigned paramIdx : range (toType->getNumParameters ())) {
3547
+ arguments.push_back (thunkArguments[paramIdx].getValue ());
3548
+ }
3549
+ auto *apply =
3550
+ thunkSGF.B .createApply (loc, linearMapArg, SubstitutionMap (), arguments);
3551
+
3552
+ // Get return elements.
3553
+ SmallVector<SILValue, 4 > results;
3554
+ extractAllElements (apply, loc, thunkSGF.B , results);
3555
+
3556
+ // Handle self reordering.
3557
+ // For pullbacks: rotate direct results if self is direct.
3558
+ if (reorderSelf && linearMapKind == AutoDiffLinearMapKind::Pullback) {
3559
+ auto fromSelfResult = fromConv.getResults ().front ();
3560
+ auto toSelfResult = toConv.getResults ().back ();
3561
+ assert (fromSelfResult.getInterfaceType () ==
3562
+ toSelfResult.getInterfaceType ());
3563
+ // Before: [dir_res_self, dir_res1, dir_res2, ...]
3564
+ // After: [dir_res1, dir_res2, ..., dir_res_self]
3565
+ if (results.size () > 1 ) {
3566
+ std::rotate (results.begin (), results.begin () + 1 , results.end ());
3567
+ }
3568
+ }
3569
+ auto retVal = joinElements (results, thunkSGF.B , loc);
3570
+
3571
+ // Emit cleanups.
3572
+ thunkSGF.Cleanups .emitCleanupsForReturn (CleanupLocation (loc), NotForUnwind);
3573
+
3574
+ // Create return.
3575
+ thunkSGF.B .createReturn (loc, retVal);
3576
+
3577
+ return getThunkedResult ();
3578
+ }
3500
3579
3501
3580
SmallVector<ManagedValue, 4 > thunkArguments;
3502
3581
for (auto *indRes : thunkIndirectResults)
@@ -3833,7 +3912,10 @@ SILFunction *SILGenModule::getOrCreateCustomDerivativeThunk(
3833
3912
};
3834
3913
3835
3914
if (!reorderSelf && linearMapFnType == targetLinearMapFnType) {
3836
- createReturn (apply);
3915
+ SmallVector<SILValue, 8 > results;
3916
+ extractAllElements (apply, loc, thunkSGF.B , results);
3917
+ auto result = joinElements (results, thunkSGF.B , apply.getLoc ());
3918
+ createReturn (result);
3837
3919
return thunk;
3838
3920
}
3839
3921
0 commit comments