@@ -3715,66 +3715,65 @@ SILGenFunction::getThunkedAutoDiffLinearMap(
3715
3715
}
3716
3716
3717
3717
// SWIFT_ENABLE_TENSORFLOW
3718
- SILFunction *SILGenModule::getOrCreateAutoDiffDerivativeReabstractionThunk (
3719
- SILFunction *original, AutoDiffConfig config, SILFunction *derivativeFn,
3720
- AutoDiffDerivativeFunctionKind derivativeFnKind, bool reorderSelf) {
3721
- auto derivativeFnType = derivativeFn->getLoweredFunctionType ();
3722
-
3723
- // TODO(TF-685): Use principled thunk mangling.
3724
- // Do not simply reuse reabstraction thunk mangling.
3725
- Mangle::ASTMangler mangler;
3726
- auto name = getASTContext ()
3727
- .getIdentifier (mangler.mangleAutoDiffDerivativeFunctionHelper (
3728
- original->getName (), derivativeFnKind, config))
3729
- .str ();
3730
- auto *thunkGenericEnv = derivativeFnType->getSubstGenericSignature ()
3731
- ? derivativeFnType->getSubstGenericSignature ()->getGenericEnvironment ()
3718
+ SILFunction *SILGenModule::getOrCreateCustomDerivativeThunk (
3719
+ SILFunction *customDerivativeFn, SILFunction *originalFn,
3720
+ const AutoDiffConfig &config, AutoDiffDerivativeFunctionKind kind) {
3721
+ auto indices = config.getSILAutoDiffIndices ();
3722
+
3723
+ auto customDerivativeFnTy = customDerivativeFn->getLoweredFunctionType ();
3724
+ auto *thunkGenericEnv = customDerivativeFnTy->getSubstGenericSignature ()
3725
+ ? customDerivativeFnTy->getSubstGenericSignature ()->getGenericEnvironment ()
3732
3726
: nullptr ;
3733
3727
3734
- auto origFnType = original->getLoweredFunctionType ();
3735
- assert (config.resultIndices ->getNumIndices () == 1 &&
3736
- " Only single result index is currently supported" );
3728
+ auto origFnTy = originalFn->getLoweredFunctionType ();
3737
3729
CanGenericSignature derivativeCanGenSig;
3738
3730
if (auto derivativeGenSig = config.derivativeGenericSignature )
3739
3731
derivativeCanGenSig = derivativeGenSig->getCanonicalSignature ();
3740
- auto origDerivativeFnType = origFnType ->getAutoDiffDerivativeFunctionType (
3741
- config. parameterIndices , *config. resultIndices -> getIndices (). begin () ,
3742
- derivativeFnKind , Types, LookUpConformanceInModule (M.getSwiftModule ()),
3732
+ auto thunkFnTy = origFnTy ->getAutoDiffDerivativeFunctionType (
3733
+ indices. parameters , indices. source ,
3734
+ kind , Types, LookUpConformanceInModule (M.getSwiftModule ()),
3743
3735
derivativeCanGenSig);
3744
- assert (!origDerivativeFnType->getExtInfo ().hasContext ());
3736
+ assert (!thunkFnTy->getExtInfo ().hasContext ());
3737
+
3738
+ // TODO(TF-685): Use principled thunk mangling.
3739
+ // Do not simply reuse reabstraction thunk mangling.
3740
+ Mangle::ASTMangler mangler;
3741
+ auto name = getASTContext ().getIdentifier (
3742
+ mangler.mangleAutoDiffDerivativeFunctionHelper (
3743
+ originalFn->getName (), kind, config)).str ();
3745
3744
3746
- auto loc = derivativeFn ->getLocation ();
3745
+ auto loc = customDerivativeFn ->getLocation ();
3747
3746
SILGenFunctionBuilder fb (*this );
3748
3747
// This thunk is publicly exposed and cannot be transparent.
3749
3748
// Instead, mark it as "always inline" for optimization.
3750
3749
auto *thunk = fb.getOrCreateFunction (
3751
- loc, name, original ->getLinkage (), origDerivativeFnType , IsBare,
3752
- IsNotTransparent, derivativeFn ->isSerialized (),
3753
- derivativeFn ->isDynamicallyReplaceable (), derivativeFn ->getEntryCount (),
3754
- derivativeFn-> isThunk (), derivativeFn ->getClassSubclassScope ());
3750
+ loc, name, customDerivativeFn ->getLinkage (), thunkFnTy , IsBare,
3751
+ IsNotTransparent, customDerivativeFn ->isSerialized (),
3752
+ customDerivativeFn ->isDynamicallyReplaceable (), customDerivativeFn ->getEntryCount (),
3753
+ IsThunk, customDerivativeFn ->getClassSubclassScope ());
3755
3754
thunk->setInlineStrategy (AlwaysInline);
3756
3755
if (!thunk->empty ())
3757
3756
return thunk;
3758
3757
thunk->setGenericEnvironment (thunkGenericEnv);
3759
3758
3760
- SILGenFunction thunkSGF (*this , *thunk, derivativeFn ->getDeclContext ());
3759
+ SILGenFunction thunkSGF (*this , *thunk, customDerivativeFn ->getDeclContext ());
3761
3760
SmallVector<ManagedValue, 4 > params;
3762
3761
SmallVector<SILArgument *, 4 > indirectResults;
3763
3762
thunkSGF.collectThunkParams (loc, params, &indirectResults);
3764
3763
3765
- auto *derivativeFnRef = thunkSGF.B .createFunctionRef (loc, derivativeFn );
3766
- auto derivativeFnRefType =
3767
- derivativeFnRef ->getType ().castTo <SILFunctionType>();
3764
+ auto *fnRef = thunkSGF.B .createFunctionRef (loc, customDerivativeFn );
3765
+ auto fnRefType =
3766
+ fnRef ->getType ().castTo <SILFunctionType>();
3768
3767
3769
3768
// Collect thunk arguments, converting ownership.
3770
3769
SmallVector<SILValue, 8 > arguments;
3771
3770
for (auto *indRes : indirectResults)
3772
3771
arguments.push_back (indRes);
3773
- forwardFunctionArguments (thunkSGF, loc, derivativeFnRefType , params,
3772
+ forwardFunctionArguments (thunkSGF, loc, fnRefType , params,
3774
3773
arguments);
3775
3774
// Apply function argument.
3776
3775
auto apply = thunkSGF.emitApplyWithRethrow (
3777
- loc, derivativeFnRef , /* substFnType*/ derivativeFnRef ->getType (),
3776
+ loc, fnRef , /* substFnType*/ fnRef ->getType (),
3778
3777
thunk->getForwardingSubstitutionMap (), arguments);
3779
3778
3780
3779
// Create return instruction in the thunk, first deallocating local
@@ -3787,15 +3786,27 @@ SILFunction *SILGenModule::getOrCreateAutoDiffDerivativeReabstractionThunk(
3787
3786
thunkSGF.B .createReturn (loc, retValue);
3788
3787
};
3789
3788
3789
+ // Self reordering thunk is necessary if wrt at least two parameters,
3790
+ // including self.
3791
+ auto shouldReorderSelf = [&]() {
3792
+ if (!originalFn->hasSelfParam ())
3793
+ return false ;
3794
+ auto selfParamIndex = origFnTy->getNumParameters () - 1 ;
3795
+ if (!indices.isWrtParameter (selfParamIndex))
3796
+ return false ;
3797
+ return indices.parameters ->getNumIndices () > 1 ;
3798
+ };
3799
+ bool reorderSelf = shouldReorderSelf ();
3800
+
3790
3801
// If self ordering is not necessary and linear map types are unchanged,
3791
3802
// return the `apply` instruction.
3792
3803
auto linearMapFnType = cast<SILFunctionType>(
3793
3804
thunk
3794
3805
->mapTypeIntoContext (
3795
- derivativeFnRefType ->getResults ().back ().getInterfaceType ())
3806
+ fnRefType ->getResults ().back ().getInterfaceType ())
3796
3807
->getCanonicalType ());
3797
3808
auto targetLinearMapFnType = thunk->mapTypeIntoContext (
3798
- origDerivativeFnType ->getResults ().back ().getSILStorageInterfaceType ())
3809
+ thunkFnTy ->getResults ().back ().getSILStorageInterfaceType ())
3799
3810
.castTo <SILFunctionType>();
3800
3811
if (!reorderSelf && linearMapFnType == targetLinearMapFnType) {
3801
3812
createReturn (apply);
@@ -3807,7 +3818,7 @@ SILFunction *SILGenModule::getOrCreateAutoDiffDerivativeReabstractionThunk(
3807
3818
extractAllElements (apply, loc, thunkSGF.B , directResults);
3808
3819
auto linearMap = thunkSGF.emitManagedRValueWithCleanup (directResults.back ());
3809
3820
assert (linearMap.getType ().castTo <SILFunctionType>() == linearMapFnType);
3810
- auto linearMapKind = derivativeFnKind .getLinearMapKind ();
3821
+ auto linearMapKind = kind .getLinearMapKind ();
3811
3822
linearMap = thunkSGF.getThunkedAutoDiffLinearMap (
3812
3823
linearMap, linearMapKind, linearMapFnType, targetLinearMapFnType,
3813
3824
reorderSelf);
0 commit comments