@@ -751,6 +751,132 @@ void SILGenModule::postEmitFunction(SILDeclRef constant,
751
751
LLVM_DEBUG (llvm::dbgs () << " lowered sil:\n " ;
752
752
F->print (llvm::dbgs ()));
753
753
F->verify ();
754
+
755
+ emitDifferentiabilityWitnessesForFunction (constant, F);
756
+ }
757
+
758
+ void SILGenModule::emitDifferentiabilityWitnessesForFunction (
759
+ SILDeclRef constant, SILFunction *F) {
760
+ // Visit `@differentiable` amd `@derivative` attributes and generate SIL
761
+ // differentiability witnesses.
762
+ // Skip if the SILDeclRef is a:
763
+ // - Default argument generator function.
764
+ // - Thunk.
765
+ if (!constant.hasDecl () || !constant.getAbstractFunctionDecl ())
766
+ return ;
767
+ if (constant.kind == SILDeclRef::Kind::DefaultArgGenerator ||
768
+ constant.isThunk ())
769
+ return ;
770
+ auto *AFD = constant.getAbstractFunctionDecl ();
771
+ auto emitWitnesses = [&](DeclAttributes &Attrs) {
772
+ for (auto *diffAttr : Attrs.getAttributes <DifferentiableAttr>()) {
773
+ SILFunction *jvp = nullptr ;
774
+ SILFunction *vjp = nullptr ;
775
+ if (auto *jvpDecl = diffAttr->getJVPFunction ())
776
+ jvp = getFunction (SILDeclRef (jvpDecl), ForDefinition);
777
+ if (auto *vjpDecl = diffAttr->getVJPFunction ())
778
+ vjp = getFunction (SILDeclRef (vjpDecl), ForDefinition);
779
+ auto *resultIndices = IndexSubset::get (getASTContext (), 1 , {0 });
780
+ assert ((!F->getLoweredFunctionType ()->getSubstGenericSignature () ||
781
+ diffAttr->getDerivativeGenericSignature ()) &&
782
+ " Type-checking should resolve derivative generic signatures for "
783
+ " all original SIL functions with generic signatures" );
784
+ AutoDiffConfig config (diffAttr->getParameterIndices (), resultIndices,
785
+ diffAttr->getDerivativeGenericSignature ());
786
+ emitDifferentiabilityWitness (AFD, F, config, jvp, vjp, diffAttr);
787
+ }
788
+ for (auto *derivAttr : Attrs.getAttributes <DerivativeAttr>()) {
789
+ SILFunction *jvp = nullptr ;
790
+ SILFunction *vjp = nullptr ;
791
+ switch (derivAttr->getDerivativeKind ()) {
792
+ case AutoDiffDerivativeFunctionKind::JVP:
793
+ jvp = F;
794
+ break ;
795
+ case AutoDiffDerivativeFunctionKind::VJP:
796
+ vjp = F;
797
+ break ;
798
+ }
799
+ auto *origAFD = derivAttr->getOriginalFunction ();
800
+ auto origDeclRef =
801
+ SILDeclRef (origAFD).asForeign (requiresForeignEntryPoint (origAFD));
802
+ auto *origFn = getFunction (origDeclRef, NotForDefinition);
803
+ auto derivativeGenSig = AFD->getGenericSignature ();
804
+ auto *resultIndices = IndexSubset::get (getASTContext (), 1 , {0 });
805
+ AutoDiffConfig config (derivAttr->getParameterIndices (), resultIndices,
806
+ derivativeGenSig);
807
+ emitDifferentiabilityWitness (origAFD, origFn, config, jvp, vjp,
808
+ derivAttr);
809
+ }
810
+ };
811
+ if (auto *accessor = dyn_cast<AccessorDecl>(AFD))
812
+ if (accessor->isGetter ())
813
+ emitWitnesses (accessor->getStorage ()->getAttrs ());
814
+ emitWitnesses (AFD->getAttrs ());
815
+ }
816
+
817
+ void SILGenModule::emitDifferentiabilityWitness (
818
+ AbstractFunctionDecl *originalAFD, SILFunction *originalFunction,
819
+ const AutoDiffConfig &config, SILFunction *jvp, SILFunction *vjp,
820
+ const DeclAttribute *attr) {
821
+ assert (isa<DifferentiableAttr>(attr) || isa<DerivativeAttr>(attr));
822
+ auto *origFnType = originalAFD->getInterfaceType ()->castTo <AnyFunctionType>();
823
+ auto origSilFnType = originalFunction->getLoweredFunctionType ();
824
+ auto *silParamIndices =
825
+ autodiff::getLoweredParameterIndices (config.parameterIndices , origFnType);
826
+ // NOTE(TF-893): Extending capacity is necessary when `origSilFnType` has
827
+ // parameters corresponding to captured variables. These parameters do not
828
+ // appear in the type of `origFnType`.
829
+ // TODO: If posssible, change `autodiff::getLoweredParameterIndices` to
830
+ // take `CaptureInfo` into account.
831
+ if (origSilFnType->getNumParameters () > silParamIndices->getCapacity ())
832
+ silParamIndices = silParamIndices->extendingCapacity (
833
+ getASTContext (), origSilFnType->getNumParameters ());
834
+
835
+ // Get or create new SIL differentiability witness.
836
+ // Witness already exists when there are two `@derivative` attributes
837
+ // (registering JVP and VJP functions) for the same derivative function
838
+ // configuration.
839
+ // Witness JVP and VJP are set below.
840
+ AutoDiffConfig silConfig (silParamIndices, config.resultIndices ,
841
+ config.derivativeGenericSignature );
842
+ SILDifferentiabilityWitnessKey key{originalFunction->getName (), silConfig};
843
+ auto *diffWitness = M.lookUpDifferentiabilityWitness (key);
844
+ if (!diffWitness) {
845
+ // Strip external from linkage of original function.
846
+ // Necessary for Clang-imported functions, which have external linkage.
847
+ auto linkage = stripExternalFromLinkage (originalFunction->getLinkage ());
848
+ diffWitness = SILDifferentiabilityWitness::createDefinition (
849
+ M, linkage, originalFunction, silConfig.parameterIndices ,
850
+ silConfig.resultIndices , config.derivativeGenericSignature ,
851
+ /* jvp*/ nullptr , /* vjp*/ nullptr ,
852
+ /* isSerialized*/ hasPublicVisibility (originalFunction->getLinkage ()),
853
+ attr);
854
+ }
855
+
856
+ // Set derivative function in differentiability witness.
857
+ auto setDerivativeInDifferentiabilityWitness =
858
+ [&](AutoDiffDerivativeFunctionKind kind, SILFunction *derivative) {
859
+ auto derivativeThunk = getOrCreateCustomDerivativeThunk (
860
+ derivative, originalFunction, silConfig, kind);
861
+ // Check for existing same derivative.
862
+ // TODO(TF-835): Remove condition below and simplify assertion to
863
+ // `!diffWitness->getDerivative(kind)` after `@derivative` attribute
864
+ // type-checking no longer generates implicit `@differentiable`
865
+ // attributes.
866
+ auto *existingDerivative = diffWitness->getDerivative (kind);
867
+ if (existingDerivative && existingDerivative == derivativeThunk)
868
+ return ;
869
+ assert (!existingDerivative &&
870
+ " SIL differentiability witness already has a different existing "
871
+ " derivative" );
872
+ diffWitness->setDerivative (kind, derivativeThunk);
873
+ };
874
+ if (jvp)
875
+ setDerivativeInDifferentiabilityWitness (AutoDiffDerivativeFunctionKind::JVP,
876
+ jvp);
877
+ if (vjp)
878
+ setDerivativeInDifferentiabilityWitness (AutoDiffDerivativeFunctionKind::VJP,
879
+ vjp);
754
880
}
755
881
756
882
void SILGenModule::
0 commit comments