@@ -752,87 +752,135 @@ void SILGenModule::postEmitFunction(SILDeclRef constant,
752
752
F->print (llvm::dbgs ()));
753
753
754
754
// SWIFT_ENABLE_TENSORFLOW
755
- // Create self-reordering thunks for JVPs/VJPs of `@differentiable` methods.
756
- if (constant.hasDecl () && constant.getAbstractFunctionDecl ()) {
755
+ // Visit `@differentiable` attributes and generate SIL differentiability
756
+ // witnesses.
757
+ // TODO(TF-835): Visit `@differentiating` attributes when type-checking no
758
+ // longer generates implicit `@differentiable` attributes. See TF-835 for
759
+ // replacement code.
760
+ // Skip if the SILDeclRef is a:
761
+ // - Default argument generator function.
762
+ // - Thunk.
763
+ if (constant.hasDecl () && constant.getAbstractFunctionDecl () &&
764
+ constant.kind != SILDeclRef::Kind::DefaultArgGenerator &&
765
+ !constant.isThunk ()) {
757
766
auto *AFD = constant.getAbstractFunctionDecl ();
758
- auto origFnType = AFD->getInterfaceType ()->castTo <AnyFunctionType>();
759
- auto origSilFnType = F->getLoweredFunctionType ();
760
- // Jointly iterate over AST `@differentiable` attributes and SIL
761
- // `[differentiable]` attributes.
762
- auto diffAttrs = AFD->getAttrs ().getAttributes <DifferentiableAttr>();
763
- auto silDiffAttrs = F->getDifferentiableAttrs ();
764
- for (auto pair : llvm::zip (diffAttrs, silDiffAttrs)) {
765
- auto *diffAttr = const_cast <DifferentiableAttr *>(std::get<0 >(pair));
766
- auto *silDiffAttr = std::get<1 >(pair);
767
- // Compute lowered parameter indices.
768
- auto *paramIndices = diffAttr->getParameterIndices ();
769
- auto *loweredParamIndices = autodiff::getLoweredParameterIndices (
770
- paramIndices, origFnType);
771
- SILAutoDiffIndices indices (/* source*/ 0 , loweredParamIndices);
772
- assert (silDiffAttr->getIndices () == indices &&
773
- " Expected matching @differentiable and [differentiable] indices" );
774
-
775
- auto lookUpConformance = LookUpConformanceInModule (M.getSwiftModule ());
776
- auto expectedJVPType = origSilFnType->getAutoDiffDerivativeFunctionType (
777
- indices.parameters , indices.source ,
778
- AutoDiffDerivativeFunctionKind::JVP, Types, lookUpConformance);
779
- auto expectedVJPType = origSilFnType->getAutoDiffDerivativeFunctionType (
780
- indices.parameters , indices.source ,
781
- AutoDiffDerivativeFunctionKind::VJP, Types, lookUpConformance);
782
-
783
- // Self reordering is necessary if wrt at least two parameters, including
784
- // self.
785
- auto shouldReorderSelf = [&]() {
786
- if (!F->hasSelfParam ())
787
- return false ;
788
- auto selfParamIndex = origSilFnType->getNumParameters () - 1 ;
789
- if (!indices.isWrtParameter (selfParamIndex))
790
- return false ;
791
- return indices.parameters ->getNumIndices () > 1 ;
792
- };
793
- bool reorderSelf = shouldReorderSelf ();
794
-
795
- // Thunk JVP method, if it is defined.
796
- if (auto *jvpDecl = diffAttr->getJVPFunction ()) {
797
- SILFunction *jvpThunk;
798
- auto *jvpFn = getFunction (SILDeclRef (jvpDecl), NotForDefinition);
799
- if (reorderSelf || jvpFn->getLoweredFunctionType () != expectedJVPType) {
800
- jvpThunk = getOrCreateAutoDiffDerivativeFunctionThunk (
801
- F, indices, jvpFn, AutoDiffDerivativeFunctionKind::JVP,
802
- reorderSelf);
803
- } else {
804
- auto *id = AutoDiffDerivativeFunctionIdentifier::get (
805
- AutoDiffDerivativeFunctionKind::JVP,
806
- diffAttr->getParameterIndices (), AFD->getASTContext ());
807
- jvpThunk = getOrCreateAutoDiffThunk (
808
- constant.asAutoDiffDerivativeFunction (id), jvpFn,
809
- expectedJVPType);
810
- }
811
- silDiffAttr->setJVPName (jvpThunk->getName ());
812
- }
813
- // Thunk VJP method, if it is defined.
814
- if (auto *vjpDecl = diffAttr->getVJPFunction ()) {
815
- SILFunction *vjpThunk;
816
- auto *vjpFn = getFunction (SILDeclRef (vjpDecl), NotForDefinition);
817
- if (reorderSelf || vjpFn->getLoweredFunctionType () != expectedVJPType) {
818
- vjpThunk = getOrCreateAutoDiffDerivativeFunctionThunk (
819
- F, indices, vjpFn, AutoDiffDerivativeFunctionKind::VJP,
820
- reorderSelf);
821
- } else {
822
- auto *id = AutoDiffDerivativeFunctionIdentifier::get (
823
- AutoDiffDerivativeFunctionKind::VJP,
824
- diffAttr->getParameterIndices (), AFD->getASTContext ());
825
- vjpThunk = getOrCreateAutoDiffThunk (
826
- constant.asAutoDiffDerivativeFunction (id), vjpFn,
827
- expectedVJPType);
828
- }
829
- silDiffAttr->setVJPName (vjpThunk->getName ());
830
- }
767
+ // Visit all `@differentiable` attributes.
768
+ for (auto *diffAttr : AFD->getAttrs ().getAttributes <DifferentiableAttr>()) {
769
+ SILFunction *jvp = nullptr ;
770
+ SILFunction *vjp = nullptr ;
771
+ if (auto *jvpDecl = diffAttr->getJVPFunction ())
772
+ jvp = getFunction (SILDeclRef (jvpDecl), NotForDefinition);
773
+ if (auto *vjpDecl = diffAttr->getVJPFunction ())
774
+ vjp = getFunction (SILDeclRef (vjpDecl), NotForDefinition);
775
+ auto *resultIndices = IndexSubset::get (getASTContext (), 1 , {0 });
776
+ AutoDiffConfig config{diffAttr->getParameterIndices (), resultIndices,
777
+ diffAttr->getDerivativeGenericSignature ()};
778
+ emitDifferentiabilityWitness (AFD, F, config, jvp, vjp);
831
779
}
832
780
}
833
781
F->verify ();
834
782
}
835
783
784
+ void SILGenModule::emitDifferentiabilityWitness (
785
+ AbstractFunctionDecl *originalAFD, SILFunction *originalFunction,
786
+ const AutoDiffConfig &config, SILFunction *jvp, SILFunction *vjp) {
787
+ auto *origFnType = originalAFD->getInterfaceType ()->castTo <AnyFunctionType>();
788
+ auto origSilFnType = originalFunction->getLoweredFunctionType ();
789
+ auto *loweredParamIndices = autodiff::getLoweredParameterIndices (
790
+ config.parameterIndices , origFnType);
791
+ // NOTE(TF-893): Extending capacity is necessary when `origSilFnType` has
792
+ // parameters corresponding to captured variables. These parameters do not
793
+ // appear in the type of `origFnType`.
794
+ // TODO: If posssible, change `autodiff::getLoweredParameterIndices` to
795
+ // take `CaptureInfo` into account.
796
+ if (origSilFnType->getNumParameters () > loweredParamIndices->getCapacity ())
797
+ loweredParamIndices = loweredParamIndices->extendingCapacity (
798
+ getASTContext (), origSilFnType->getNumParameters ());
799
+ // TODO(TF-913): Replace usages of `SILAutoDiffIndices` with `AutoDiffConfig`.
800
+ SILAutoDiffIndices indices (/* source*/ 0 , loweredParamIndices);
801
+
802
+ // Self reordering thunk is necessary if wrt at least two parameters,
803
+ // including self.
804
+ auto shouldReorderSelf = [&]() {
805
+ if (!originalFunction->hasSelfParam ())
806
+ return false ;
807
+ auto selfParamIndex = origSilFnType->getNumParameters () - 1 ;
808
+ if (!indices.isWrtParameter (selfParamIndex))
809
+ return false ;
810
+ return indices.parameters ->getNumIndices () > 1 ;
811
+ };
812
+ bool reorderSelf = shouldReorderSelf ();
813
+
814
+ CanGenericSignature derivativeCanGenSig;
815
+ if (auto *derivativeGenSig = config.derivativeGenericSignature )
816
+ derivativeCanGenSig = derivativeGenSig->getCanonicalSignature ();
817
+ // TODO(TF-835): Use simpler derivative generic signature logic below when
818
+ // type-checking no longer generates implicit `@differentiable` attributes.
819
+ // See TF-835 for replacement code.
820
+ if (jvp) {
821
+ auto jvpCanGenSig = jvp->getLoweredFunctionType ()->getGenericSignature ();
822
+ if (!derivativeCanGenSig && jvpCanGenSig)
823
+ derivativeCanGenSig = jvpCanGenSig;
824
+ assert (derivativeCanGenSig == jvpCanGenSig);
825
+ }
826
+ if (vjp) {
827
+ auto vjpCanGenSig = vjp->getLoweredFunctionType ()->getGenericSignature ();
828
+ if (!derivativeCanGenSig && vjpCanGenSig)
829
+ derivativeCanGenSig = vjpCanGenSig;
830
+ assert (derivativeCanGenSig == vjpCanGenSig);
831
+ }
832
+ // Create new SIL differentiability witness.
833
+ // Witness JVP and VJP are set below.
834
+ // TODO(TF-919): Explore creating serialized differentiability witnesses.
835
+ // Currently, differentiability witnesses are never serialized to avoid
836
+ // deserialization issues where JVP/VJP functions cannot be found.
837
+ auto *diffWitness = SILDifferentiabilityWitness::create (
838
+ M, originalFunction->getLinkage (), originalFunction,
839
+ loweredParamIndices, config.resultIndices , derivativeCanGenSig,
840
+ /* jvp*/ nullptr , /* vjp*/ nullptr , /* isSerialized*/ false );
841
+
842
+ // Set derivative function in differentiability witness.
843
+ auto setDerivativeInDifferentiabilityWitness =
844
+ [&](AutoDiffDerivativeFunctionKind kind, SILFunction *derivative) {
845
+ auto expectedDerivativeType =
846
+ origSilFnType->getAutoDiffDerivativeFunctionType (
847
+ indices.parameters , indices.source , kind, Types,
848
+ LookUpConformanceInModule (M.getSwiftModule ()));
849
+ // Thunk derivative function.
850
+ SILFunction *derivativeThunk;
851
+ if (reorderSelf ||
852
+ derivative->getLoweredFunctionType () != expectedDerivativeType) {
853
+ derivativeThunk = getOrCreateAutoDiffDerivativeFunctionThunk (
854
+ originalFunction, indices, derivative, kind, reorderSelf);
855
+ } else {
856
+ // Note: `AutoDiffDerivativeFunctionIdentifier` must be constructed with
857
+ // the AST-level parameter indices, not the SIL-level ones.
858
+ auto *id = AutoDiffDerivativeFunctionIdentifier::get (
859
+ kind, config.parameterIndices , getASTContext ());
860
+ derivativeThunk = getOrCreateAutoDiffThunk (
861
+ SILDeclRef (originalAFD).asAutoDiffDerivativeFunction (id), derivative,
862
+ expectedDerivativeType);
863
+ }
864
+ // Check for existing same derivative.
865
+ // TODO(TF-835): Remove condition below and simplify assertion to
866
+ // `!diffWitness->getDerivative(kind)` after `@differentiating` attribute
867
+ // type-checking no longer generates implicit `@differentiable` attributes.
868
+ auto *existingDerivative = diffWitness->getDerivative (kind);
869
+ if (existingDerivative && existingDerivative == derivativeThunk)
870
+ return ;
871
+ assert (!existingDerivative &&
872
+ " SIL differentiability witness already has a different existing "
873
+ " derivative" );
874
+ diffWitness->setDerivative (kind, derivativeThunk);
875
+ };
876
+ if (jvp)
877
+ setDerivativeInDifferentiabilityWitness (AutoDiffDerivativeFunctionKind::JVP,
878
+ jvp);
879
+ if (vjp)
880
+ setDerivativeInDifferentiabilityWitness (AutoDiffDerivativeFunctionKind::VJP,
881
+ vjp);
882
+ }
883
+
836
884
void SILGenModule::
837
885
emitMarkFunctionEscapeForTopLevelCodeGlobals (SILLocation loc,
838
886
const CaptureInfo &captureInfo) {
0 commit comments