@@ -3364,7 +3364,6 @@ static CanSILFunctionType buildWithoutActuallyEscapingThunkType(
3364
3364
// / tuple. Otherwise, add this value directly to `result`.
3365
3365
static void extractAllElements (SILValue val, SILBuilder &builder,
3366
3366
SmallVectorImpl<SILValue> &result) {
3367
- // auto &fn = builder.getFunction();
3368
3367
if (auto tupleType = val->getType ().getAs <TupleType>())
3369
3368
for (auto i : range (tupleType->getNumElements ()))
3370
3369
result.push_back (builder.createTupleExtract (val.getLoc (), val, i));
@@ -3385,10 +3384,11 @@ static SILValue joinElements(ArrayRef<SILValue> elements, SILBuilder &builder,
3385
3384
3386
3385
// SWIFT_ENABLE_TENSORFLOW
3387
3386
// / Adapted from `SILGenModule::getOrCreateReabstractionThunk`.
3388
- SILFunction *
3389
- SILGenFunction::getOrCreateAutoDiffLinearMapThunk (
3390
- AutoDiffAssociatedFunctionKind assocFnKind, CanSILFunctionType fromType,
3391
- CanSILFunctionType toType, bool reorderSelf) {
3387
+ ManagedValue
3388
+ SILGenFunction::getThunkedAutoDiffLinearMap (
3389
+ ManagedValue linearMap, AutoDiffAssociatedFunctionKind assocFnKind,
3390
+ CanSILFunctionType fromType, CanSILFunctionType toType,
3391
+ bool reorderSelf) {
3392
3392
// Compute the thunk type.
3393
3393
SubstitutionMap interfaceSubs;
3394
3394
GenericEnvironment *genericEnv = nullptr ;
@@ -3409,7 +3409,7 @@ SILGenFunction::getOrCreateAutoDiffLinearMapThunk(
3409
3409
std::string name = mangler.mangleReabstractionThunkHelper (
3410
3410
thunkType, fromInterfaceType, toInterfaceType,
3411
3411
Type (), getModule ().getSwiftModule ());
3412
- // TODO: Use principled mangling.
3412
+ // TODO(TF-685) : Use principled thunk mangling.
3413
3413
if (reorderSelf) {
3414
3414
switch (assocFnKind) {
3415
3415
case AutoDiffAssociatedFunctionKind::JVP:
@@ -3428,8 +3428,20 @@ SILGenFunction::getOrCreateAutoDiffLinearMapThunk(
3428
3428
auto *thunk = fb.getOrCreateSharedFunction (
3429
3429
loc, name, thunkDeclType, IsBare, IsTransparent, IsSerialized,
3430
3430
ProfileCounter (), IsReabstractionThunk, IsNotDynamic);
3431
+
3432
+ // Partially-apply the thunk to `linearMap` and return the thunked value.
3433
+ auto getThunkedResult = [&]() {
3434
+ auto thunkedFn = createPartialApplyOfThunk (
3435
+ *this , loc, thunk, interfaceSubs, dynamicSelfType, toType, linearMap);
3436
+ if (!toType->isNoEscape ())
3437
+ return thunkedFn;
3438
+ // Handle escaping to noescape conversion.
3439
+ return B.createConvertEscapeToNoEscape (
3440
+ loc, thunkedFn, SILType::getPrimitiveObjectType (toType));
3441
+ };
3442
+
3431
3443
if (!thunk->empty ())
3432
- return thunk ;
3444
+ return getThunkedResult () ;
3433
3445
thunk->setGenericEnvironment (genericEnv);
3434
3446
thunk->setOwnershipEliminated ();
3435
3447
@@ -3560,9 +3572,9 @@ SILGenFunction::getOrCreateAutoDiffLinearMapThunk(
3560
3572
arguments.push_back (load);
3561
3573
}
3562
3574
3563
- auto linearMap = thunk->getArgumentsWithoutIndirectResults ().back ();
3575
+ auto *linearMapArg = thunk->getArgumentsWithoutIndirectResults ().back ();
3564
3576
auto *apply = thunkSGF.B .createApply (
3565
- loc, linearMap , SubstitutionMap (), arguments, /* isNonThrowing*/ false );
3577
+ loc, linearMapArg , SubstitutionMap (), arguments, /* isNonThrowing*/ false );
3566
3578
3567
3579
// Get return elements.
3568
3580
SmallVector<SILValue, 4 > results;
@@ -3630,7 +3642,7 @@ SILGenFunction::getOrCreateAutoDiffLinearMapThunk(
3630
3642
3631
3643
// Create return.
3632
3644
thunkSGF.B .createReturn (loc, retVal);
3633
- return thunk ;
3645
+ return getThunkedResult () ;
3634
3646
}
3635
3647
3636
3648
// / Forward function arguments, converting ownership.
@@ -3687,9 +3699,12 @@ static void forwardFunctionArgumentsConvertingOwnership(
3687
3699
SILFunction *
3688
3700
SILGenModule::getOrCreateAutoDiffAssociatedFunctionThunk (
3689
3701
SILFunction *original, SILAutoDiffIndices &indices,
3690
- SILFunction *assocFn, AutoDiffAssociatedFunctionKind assocFnKind) {
3702
+ SILFunction *assocFn, AutoDiffAssociatedFunctionKind assocFnKind,
3703
+ bool reorderSelf) {
3691
3704
auto assocFnType = assocFn->getLoweredFunctionType ();
3692
3705
3706
+ // TODO(TF-685): Use principled thunk mangling.
3707
+ // Do not simply reuse reabstraction thunk mangling.
3693
3708
Mangle::ASTMangler mangler;
3694
3709
auto name = getASTContext ().getIdentifier (
3695
3710
mangler.mangleAutoDiffAssociatedFunctionHelper (
@@ -3746,8 +3761,8 @@ SILGenModule::getOrCreateAutoDiffAssociatedFunctionThunk(
3746
3761
3747
3762
SmallVector<SILValue, 8 > directResults;
3748
3763
extractAllElements (apply, thunkSGF.B , directResults);
3749
- auto linearMap = directResults.back ();
3750
- auto linearMapFnType = linearMap-> getType ().castTo <SILFunctionType>();
3764
+ auto linearMap = ManagedValue::forBorrowedObjectRValue ( directResults.back () );
3765
+ auto linearMapFnType = linearMap. getType ().castTo <SILFunctionType>();
3751
3766
auto targetLinearMapFnType = thunk->mapTypeIntoContext (
3752
3767
origAssocFnType->getResults ().back ().getSILStorageType ())
3753
3768
.castTo <SILFunctionType>();
@@ -3769,33 +3784,17 @@ SILGenModule::getOrCreateAutoDiffAssociatedFunctionThunk(
3769
3784
thunkSGF.B .createReturn (loc, retValue);
3770
3785
};
3771
3786
3772
- // If linear map types are unchanged, return the `apply` instruction.
3773
- if (linearMapFnType == targetLinearMapFnType) {
3787
+ // If self ordering is not necessary and linear map types are unchanged,
3788
+ // return the `apply` instruction.
3789
+ if (!reorderSelf && linearMapFnType == targetLinearMapFnType) {
3774
3790
createReturn (apply);
3775
3791
return thunk;
3776
3792
}
3777
3793
3778
- // Generate linear map thunk for reabstraction/self reordering.
3779
- auto shouldReorderSelf = [&]() {
3780
- if (!original->hasSelfParam ())
3781
- return false ;
3782
- auto selfParamIndex =
3783
- original->getArgumentsWithoutIndirectResults ().size () - 1 ;
3784
- if (!indices.isWrtParameter (selfParamIndex))
3785
- return false ;
3786
- return indices.parameters ->getNumIndices () > 1 ;
3787
- };
3788
- bool reorderSelf = shouldReorderSelf ();
3789
- auto *linearMapThunk = thunkSGF.getOrCreateAutoDiffLinearMapThunk (
3790
- assocFnKind, linearMapFnType, targetLinearMapFnType, reorderSelf);
3791
- auto linearMapThunkValue =
3792
- thunkSGF.B .createFunctionRefFor (loc, linearMapThunk);
3793
- SubstitutionMap linearMapSubs;
3794
- if (linearMapThunk->getLoweredFunctionType ()->isPolymorphic ())
3795
- linearMapSubs = thunk->getForwardingSubstitutionMap ();
3796
- linearMap = thunkSGF.B .createPartialApply (
3797
- loc, linearMapThunkValue, linearMapSubs, {linearMap},
3798
- linearMapFnType->getCalleeConvention ());
3794
+ // Otherwise, apply reabstraction/self reordering thunk to linear map.
3795
+ linearMap = thunkSGF.getThunkedAutoDiffLinearMap (
3796
+ linearMap, assocFnKind, linearMapFnType, targetLinearMapFnType,
3797
+ reorderSelf);
3799
3798
3800
3799
// Return original results and thunked differential/pullback.
3801
3800
if (directResults.size () > 1 ) {
@@ -3804,10 +3803,10 @@ SILGenModule::getOrCreateAutoDiffAssociatedFunctionThunk(
3804
3803
auto originalDirectResult =
3805
3804
joinElements (originalDirectResults, thunkSGF.B , apply.getLoc ());
3806
3805
auto thunkResult = joinElements (
3807
- {originalDirectResult, linearMap}, thunkSGF.B , loc);
3806
+ {originalDirectResult, linearMap. getValue () }, thunkSGF.B , loc);
3808
3807
createReturn (thunkResult);
3809
3808
} else {
3810
- createReturn (linearMap);
3809
+ createReturn (linearMap. getValue () );
3811
3810
}
3812
3811
return thunk;
3813
3812
}
0 commit comments