@@ -3691,64 +3691,65 @@ SILGenFunction::getThunkedAutoDiffLinearMap(
3691
3691
3692
3692
// SWIFT_ENABLE_TENSORFLOW
3693
3693
SILFunction *
3694
- SILGenModule::getOrCreateAutoDiffDerivativeReabstractionThunk (
3695
- SILFunction *original, SILAutoDiffIndices &indices,
3696
- SILFunction *derivativeFn, AutoDiffDerivativeFunctionKind derivativeFnKind,
3697
- bool reorderSelf) {
3698
- auto derivativeFnType = derivativeFn->getLoweredFunctionType ();
3694
+ SILGenModule::getOrCreateCustomDerivativeThunk (
3695
+ SILFunction *customDerivativeFn,
3696
+ SILFunction *originalFn, const AutoDiffConfig &config,
3697
+ AutoDiffDerivativeFunctionKind kind) {
3698
+ auto indices = config.getSILAutoDiffIndices ();
3699
+ auto customDerivativeFnTy = customDerivativeFn->getLoweredFunctionType ();
3700
+
3701
+ Lowering::GenericContextScope genericContextScope (
3702
+ Types, customDerivativeFnTy->getSubstGenericSignature ());
3703
+ auto *thunkGenericEnv = customDerivativeFnTy->getSubstGenericSignature ()
3704
+ ? customDerivativeFnTy->getSubstGenericSignature ()->getGenericEnvironment ()
3705
+ : nullptr ;
3706
+
3707
+ auto origFnTy = originalFn->getLoweredFunctionType ();
3708
+ auto thunkFnTy = origFnTy->getAutoDiffDerivativeFunctionType (
3709
+ indices.parameters , indices.source ,
3710
+ kind, Types, LookUpConformanceInModule (M.getSwiftModule ()),
3711
+ customDerivativeFnTy->getSubstGenericSignature ());
3712
+ assert (!thunkFnTy->getExtInfo ().hasContext ());
3699
3713
3700
3714
// TODO(TF-685): Use principled thunk mangling.
3701
3715
// Do not simply reuse reabstraction thunk mangling.
3702
3716
Mangle::ASTMangler mangler;
3703
3717
auto name = getASTContext ().getIdentifier (
3704
3718
mangler.mangleAutoDiffDerivativeFunctionHelper (
3705
- original ->getName (), derivativeFnKind , indices)).str ();
3719
+ originalFn ->getName (), kind , indices)).str ();
3706
3720
3707
- Lowering::GenericContextScope genericContextScope (
3708
- Types, derivativeFnType->getSubstGenericSignature ());
3709
- auto *thunkGenericEnv = derivativeFnType->getSubstGenericSignature ()
3710
- ? derivativeFnType->getSubstGenericSignature ()->getGenericEnvironment ()
3711
- : nullptr ;
3712
-
3713
- auto origFnType = original->getLoweredFunctionType ();
3714
- auto origDerivativeFnType = origFnType->getAutoDiffDerivativeFunctionType (
3715
- indices.parameters , indices.source ,
3716
- derivativeFnKind, Types, LookUpConformanceInModule (M.getSwiftModule ()),
3717
- derivativeFnType->getSubstGenericSignature ());
3718
- assert (!origDerivativeFnType->getExtInfo ().hasContext ());
3719
-
3720
- auto loc = derivativeFn->getLocation ();
3721
+ auto loc = customDerivativeFn->getLocation ();
3721
3722
SILGenFunctionBuilder fb (*this );
3722
3723
// This thunk is publicly exposed and cannot be transparent.
3723
3724
// Instead, mark it as "always inline" for optimization.
3724
3725
auto *thunk = fb.getOrCreateFunction (
3725
- loc, name, original ->getLinkage (), origDerivativeFnType , IsBare,
3726
- IsNotTransparent, derivativeFn ->isSerialized (),
3727
- derivativeFn ->isDynamicallyReplaceable (), derivativeFn ->getEntryCount (),
3728
- derivativeFn-> isThunk (), derivativeFn ->getClassSubclassScope ());
3726
+ loc, name, customDerivativeFn ->getLinkage (), thunkFnTy , IsBare,
3727
+ IsNotTransparent, customDerivativeFn ->isSerialized (),
3728
+ customDerivativeFn ->isDynamicallyReplaceable (), customDerivativeFn ->getEntryCount (),
3729
+ IsThunk, customDerivativeFn ->getClassSubclassScope ());
3729
3730
thunk->setInlineStrategy (AlwaysInline);
3730
3731
if (!thunk->empty ())
3731
3732
return thunk;
3732
3733
thunk->setGenericEnvironment (thunkGenericEnv);
3733
3734
3734
- SILGenFunction thunkSGF (*this , *thunk, derivativeFn ->getDeclContext ());
3735
+ SILGenFunction thunkSGF (*this , *thunk, customDerivativeFn ->getDeclContext ());
3735
3736
SmallVector<ManagedValue, 4 > params;
3736
3737
SmallVector<SILArgument *, 4 > indirectResults;
3737
3738
thunkSGF.collectThunkParams (loc, params, &indirectResults);
3738
3739
3739
- auto *derivativeFnRef = thunkSGF.B .createFunctionRef (loc, derivativeFn );
3740
- auto derivativeFnRefType =
3741
- derivativeFnRef ->getType ().castTo <SILFunctionType>();
3740
+ auto *fnRef = thunkSGF.B .createFunctionRef (loc, customDerivativeFn );
3741
+ auto fnRefType =
3742
+ fnRef ->getType ().castTo <SILFunctionType>();
3742
3743
3743
3744
// Collect thunk arguments, converting ownership.
3744
3745
SmallVector<SILValue, 8 > arguments;
3745
3746
for (auto *indRes : indirectResults)
3746
3747
arguments.push_back (indRes);
3747
- forwardFunctionArguments (thunkSGF, loc, derivativeFnRefType , params,
3748
+ forwardFunctionArguments (thunkSGF, loc, fnRefType , params,
3748
3749
arguments);
3749
3750
// Apply function argument.
3750
3751
auto apply = thunkSGF.emitApplyWithRethrow (
3751
- loc, derivativeFnRef , /* substFnType*/ derivativeFnRef ->getType (),
3752
+ loc, fnRef , /* substFnType*/ fnRef ->getType (),
3752
3753
thunk->getForwardingSubstitutionMap (), arguments);
3753
3754
3754
3755
// Create return instruction in the thunk, first deallocating local
@@ -3761,15 +3762,27 @@ SILGenModule::getOrCreateAutoDiffDerivativeReabstractionThunk(
3761
3762
thunkSGF.B .createReturn (loc, retValue);
3762
3763
};
3763
3764
3765
+ // Self reordering thunk is necessary if wrt at least two parameters,
3766
+ // including self.
3767
+ auto shouldReorderSelf = [&]() {
3768
+ if (!originalFn->hasSelfParam ())
3769
+ return false ;
3770
+ auto selfParamIndex = origFnTy->getNumParameters () - 1 ;
3771
+ if (!indices.isWrtParameter (selfParamIndex))
3772
+ return false ;
3773
+ return indices.parameters ->getNumIndices () > 1 ;
3774
+ };
3775
+ bool reorderSelf = shouldReorderSelf ();
3776
+
3764
3777
// If self ordering is not necessary and linear map types are unchanged,
3765
3778
// return the `apply` instruction.
3766
3779
auto linearMapFnType = cast<SILFunctionType>(
3767
3780
thunk
3768
3781
->mapTypeIntoContext (
3769
- derivativeFnRefType ->getResults ().back ().getInterfaceType ())
3782
+ fnRefType ->getResults ().back ().getInterfaceType ())
3770
3783
->getCanonicalType ());
3771
3784
auto targetLinearMapFnType = thunk->mapTypeIntoContext (
3772
- origDerivativeFnType ->getResults ().back ().getSILStorageInterfaceType ())
3785
+ thunkFnTy ->getResults ().back ().getSILStorageInterfaceType ())
3773
3786
.castTo <SILFunctionType>();
3774
3787
if (!reorderSelf && linearMapFnType == targetLinearMapFnType) {
3775
3788
createReturn (apply);
@@ -3781,7 +3794,7 @@ SILGenModule::getOrCreateAutoDiffDerivativeReabstractionThunk(
3781
3794
extractAllElements (apply, loc, thunkSGF.B , directResults);
3782
3795
auto linearMap = thunkSGF.emitManagedRValueWithCleanup (directResults.back ());
3783
3796
assert (linearMap.getType ().castTo <SILFunctionType>() == linearMapFnType);
3784
- auto linearMapKind = derivativeFnKind .getLinearMapKind ();
3797
+ auto linearMapKind = kind .getLinearMapKind ();
3785
3798
linearMap = thunkSGF.getThunkedAutoDiffLinearMap (
3786
3799
linearMap, linearMapKind, linearMapFnType, targetLinearMapFnType,
3787
3800
reorderSelf);
0 commit comments