@@ -3394,7 +3394,7 @@ static SILValue joinElements(ArrayRef<SILValue> elements, SILBuilder &builder,
3394
3394
// / Adapted from `SILGenModule::getOrCreateReabstractionThunk`.
3395
3395
ManagedValue
3396
3396
SILGenFunction::getThunkedAutoDiffLinearMap (
3397
- ManagedValue linearMap, AutoDiffAssociatedFunctionKind assocFnKind ,
3397
+ ManagedValue linearMap, AutoDiffLinearMapKind linearMapKind ,
3398
3398
CanSILFunctionType fromType, CanSILFunctionType toType,
3399
3399
bool reorderSelf) {
3400
3400
// Compute the thunk type.
@@ -3418,11 +3418,11 @@ SILGenFunction::getThunkedAutoDiffLinearMap(
3418
3418
thunkType, fromInterfaceType, toInterfaceType,
3419
3419
Type (), getModule ().getSwiftModule ());
3420
3420
// TODO(TF-685): Use principled thunk mangling.
3421
- switch (assocFnKind ) {
3422
- case AutoDiffAssociatedFunctionKind::JVP :
3421
+ switch (linearMapKind ) {
3422
+ case AutoDiffLinearMapKind::Differential :
3423
3423
name += " _differential" ;
3424
3424
break ;
3425
- case AutoDiffAssociatedFunctionKind::VJP :
3425
+ case AutoDiffLinearMapKind::Pullback :
3426
3426
name += " _pullback" ;
3427
3427
break ;
3428
3428
}
@@ -3476,20 +3476,30 @@ SILGenFunction::getThunkedAutoDiffLinearMap(
3476
3476
// - If self is direct, reorder direct results after `apply` is generated.
3477
3477
// - For differentials: reorder parameter infos and arguments.
3478
3478
auto numIndirectResults = thunkIndirectResults.size ();
3479
- if (reorderSelf && assocFnKind == AutoDiffAssociatedFunctionKind::VJP &&
3479
+ if (reorderSelf && linearMapKind == AutoDiffLinearMapKind::Pullback &&
3480
3480
toResults.size () > 1 ) {
3481
3481
auto toSelfResult = toResults.back ();
3482
3482
if (toSelfResult.isFormalIndirect () && numIndirectResults > 1 ) {
3483
+ // Before: [ind_res1, ind_res2, ..., ind_res_self, arg1, arg2, ..., pb]
3484
+ // After: [ind_res_self, ind_res1, ind_res2, ..., arg1, arg2, ..., pb]
3483
3485
std::rotate (thunkArguments.begin (),
3484
3486
thunkArguments.begin () + numIndirectResults - 1 ,
3485
3487
thunkArguments.begin () + numIndirectResults);
3488
+ // Before: [ind_res1, ind_res2, ..., ind_res_self]
3489
+ // After: [ind_res_self, ind_res1, ind_res2, ...]
3490
+ std::rotate (thunkIndirectResults.begin (), thunkIndirectResults.end () - 1 ,
3491
+ thunkIndirectResults.end ());
3486
3492
}
3487
3493
std::rotate (toResults.begin (), toResults.end () - 1 , toResults.end ());
3488
3494
}
3489
- if (reorderSelf && assocFnKind == AutoDiffAssociatedFunctionKind::JVP &&
3495
+ if (reorderSelf && linearMapKind == AutoDiffLinearMapKind::Differential &&
3490
3496
thunkArguments.size () > 1 ) {
3497
+ // Before: [ind_res1, ind_res2, ..., arg1, arg2, ..., arg_self, df]
3498
+ // After: [ind_res1, ind_res2, ..., arg_self, arg1, arg2, ..., df]
3491
3499
std::rotate (thunkArguments.begin () + numIndirectResults,
3492
3500
thunkArguments.end () - 2 , thunkArguments.end () - 1 );
3501
+ // Before: [arg1, arg2, ..., arg_self]
3502
+ // After: [arg_self, arg1, arg2, ...]
3493
3503
std::rotate (toParameters.begin (), toParameters.end () - 1 ,
3494
3504
toParameters.end ());
3495
3505
}
@@ -3589,14 +3599,12 @@ SILGenFunction::getThunkedAutoDiffLinearMap(
3589
3599
3590
3600
// Handle self reordering.
3591
3601
// For pullbacks: rotate direct results if self is direct.
3592
- if (reorderSelf && assocFnKind == AutoDiffAssociatedFunctionKind::VJP ) {
3602
+ if (reorderSelf && linearMapKind == AutoDiffLinearMapKind::Pullback ) {
3593
3603
auto fromSelfResult = fromConv.getResults ().front ();
3594
3604
auto toSelfResult = toConv.getResults ().back ();
3595
3605
assert (fromSelfResult.getType () == toSelfResult.getType ());
3596
- if (toSelfResult.isFormalIndirect () && thunkIndirectResults.size () > 1 ) {
3597
- std::rotate (thunkIndirectResults.begin (), thunkIndirectResults.end () - 1 ,
3598
- thunkIndirectResults.end ());
3599
- }
3606
+ // Before: [dir_res_self, dir_res1, dir_res2, ...]
3607
+ // After: [dir_res1, dir_res2, ..., dir_res_self]
3600
3608
if (toSelfResult.isFormalDirect () && fromSelfResult.isFormalDirect () &&
3601
3609
directResults.size () > 1 ) {
3602
3610
std::rotate (directResults.begin (), directResults.begin () + 1 ,
@@ -3802,8 +3810,9 @@ SILGenModule::getOrCreateAutoDiffAssociatedFunctionThunk(
3802
3810
}
3803
3811
3804
3812
// Otherwise, apply reabstraction/self reordering thunk to linear map.
3813
+ auto linearMapKind = assocFnKind.getLinearMapKind ();
3805
3814
linearMap = thunkSGF.getThunkedAutoDiffLinearMap (
3806
- linearMap, assocFnKind , linearMapFnType, targetLinearMapFnType,
3815
+ linearMap, linearMapKind , linearMapFnType, targetLinearMapFnType,
3807
3816
reorderSelf);
3808
3817
3809
3818
// Return original results and thunked differential/pullback.
0 commit comments