@@ -752,87 +752,140 @@ void SILGenModule::postEmitFunction(SILDeclRef constant,
752
752
F->print (llvm::dbgs ()));
753
753
754
754
// SWIFT_ENABLE_TENSORFLOW
755
- // Create self-reordering thunks for JVPs/VJPs of `@differentiable` methods.
756
- if (constant.hasDecl () && constant.getAbstractFunctionDecl ()) {
755
+ // Visit `@differentiable` and `@differentiating` attributes and generate SIL
756
+ // differentiability witnesses.
757
+ // Do not visit default argument generator functions.
758
+ if (constant.hasDecl () && constant.getAbstractFunctionDecl () &&
759
+ constant.kind != SILDeclRef::Kind::DefaultArgGenerator) {
757
760
auto *AFD = constant.getAbstractFunctionDecl ();
758
- auto origFnType = AFD->getInterfaceType ()->castTo <AnyFunctionType>();
759
- auto origSilFnType = F->getLoweredFunctionType ();
760
- // Jointly iterate over AST `@differentiable` attributes and SIL
761
- // `[differentiable]` attributes.
762
- auto diffAttrs = AFD->getAttrs ().getAttributes <DifferentiableAttr>();
763
- auto silDiffAttrs = F->getDifferentiableAttrs ();
764
- for (auto pair : llvm::zip (diffAttrs, silDiffAttrs)) {
765
- auto *diffAttr = const_cast <DifferentiableAttr *>(std::get<0 >(pair));
766
- auto *silDiffAttr = std::get<1 >(pair);
767
- // Compute lowered parameter indices.
768
- auto *paramIndices = diffAttr->getParameterIndices ();
769
- auto *loweredParamIndices = autodiff::getLoweredParameterIndices (
770
- paramIndices, origFnType);
771
- SILAutoDiffIndices indices (/* source*/ 0 , loweredParamIndices);
772
- assert (silDiffAttr->getIndices () == indices &&
773
- " Expected matching @differentiable and [differentiable] indices" );
774
-
775
- auto lookUpConformance = LookUpConformanceInModule (M.getSwiftModule ());
776
- auto expectedJVPType = origSilFnType->getAutoDiffDerivativeFunctionType (
777
- indices.parameters , indices.source ,
778
- AutoDiffDerivativeFunctionKind::JVP, Types, lookUpConformance);
779
- auto expectedVJPType = origSilFnType->getAutoDiffDerivativeFunctionType (
780
- indices.parameters , indices.source ,
781
- AutoDiffDerivativeFunctionKind::VJP, Types, lookUpConformance);
782
-
783
- // Self reordering is necessary if wrt at least two parameters, including
784
- // self.
785
- auto shouldReorderSelf = [&]() {
786
- if (!F->hasSelfParam ())
787
- return false ;
788
- auto selfParamIndex = origSilFnType->getNumParameters () - 1 ;
789
- if (!indices.isWrtParameter (selfParamIndex))
790
- return false ;
791
- return indices.parameters ->getNumIndices () > 1 ;
792
- };
793
- bool reorderSelf = shouldReorderSelf ();
794
-
795
- // Thunk JVP method, if it is defined.
796
- if (auto *jvpDecl = diffAttr->getJVPFunction ()) {
797
- SILFunction *jvpThunk;
798
- auto *jvpFn = getFunction (SILDeclRef (jvpDecl), NotForDefinition);
799
- if (reorderSelf || jvpFn->getLoweredFunctionType () != expectedJVPType) {
800
- jvpThunk = getOrCreateAutoDiffDerivativeFunctionThunk (
801
- F, indices, jvpFn, AutoDiffDerivativeFunctionKind::JVP,
802
- reorderSelf);
803
- } else {
804
- auto *id = AutoDiffDerivativeFunctionIdentifier::get (
805
- AutoDiffDerivativeFunctionKind::JVP,
806
- diffAttr->getParameterIndices (), AFD->getASTContext ());
807
- jvpThunk = getOrCreateAutoDiffThunk (
808
- constant.asAutoDiffDerivativeFunction (id), jvpFn,
809
- expectedJVPType);
810
- }
811
- silDiffAttr->setJVPName (jvpThunk->getName ());
812
- }
813
- // Thunk VJP method, if it is defined.
814
- if (auto *vjpDecl = diffAttr->getVJPFunction ()) {
815
- SILFunction *vjpThunk;
816
- auto *vjpFn = getFunction (SILDeclRef (vjpDecl), NotForDefinition);
817
- if (reorderSelf || vjpFn->getLoweredFunctionType () != expectedVJPType) {
818
- vjpThunk = getOrCreateAutoDiffDerivativeFunctionThunk (
819
- F, indices, vjpFn, AutoDiffDerivativeFunctionKind::VJP,
820
- reorderSelf);
821
- } else {
822
- auto *id = AutoDiffDerivativeFunctionIdentifier::get (
823
- AutoDiffDerivativeFunctionKind::VJP,
824
- diffAttr->getParameterIndices (), AFD->getASTContext ());
825
- vjpThunk = getOrCreateAutoDiffThunk (
826
- constant.asAutoDiffDerivativeFunction (id), vjpFn,
827
- expectedVJPType);
828
- }
829
- silDiffAttr->setVJPName (vjpThunk->getName ());
761
+ // Visit all `@differentiable` attributes.
762
+ for (auto *diffAttr : AFD->getAttrs ().getAttributes <DifferentiableAttr>()) {
763
+ SILFunction *jvp = nullptr ;
764
+ SILFunction *vjp = nullptr ;
765
+ if (auto *jvpDecl = diffAttr->getJVPFunction ())
766
+ jvp = getFunction (SILDeclRef (jvpDecl), NotForDefinition);
767
+ if (auto *vjpDecl = diffAttr->getVJPFunction ())
768
+ vjp = getFunction (SILDeclRef (vjpDecl), NotForDefinition);
769
+ emitDifferentiabilityWitness (AFD, F, diffAttr->getParameterIndices (), jvp,
770
+ vjp);
771
+ }
772
+ // Visit all `@differentiating` attributes.
773
+ for (auto *diffAttr :
774
+ AFD->getAttrs ().getAttributes <DifferentiatingAttr>()) {
775
+ auto *origAFD = diffAttr->getOriginalFunction ();
776
+ auto *origFn = getFunction (SILDeclRef (origAFD), NotForDefinition);
777
+ SILFunction *jvp = nullptr ;
778
+ SILFunction *vjp = nullptr ;
779
+ switch (diffAttr->getDerivativeKind ()) {
780
+ case AutoDiffDerivativeFunctionKind::JVP:
781
+ jvp = F;
782
+ break ;
783
+ case AutoDiffDerivativeFunctionKind::VJP:
784
+ vjp = F;
785
+ break ;
830
786
}
787
+ emitDifferentiabilityWitness (origAFD, origFn,
788
+ diffAttr->getParameterIndices (), jvp, vjp);
831
789
}
832
790
}
833
791
F->verify ();
834
792
}
835
793
794
+ void SILGenModule::emitDifferentiabilityWitness (
795
+ AbstractFunctionDecl *originalAFD, SILFunction *originalFunction,
796
+ IndexSubset *parameterIndices, SILFunction *jvp, SILFunction *vjp) {
797
+ auto *origFnType = originalAFD->getInterfaceType ()->castTo <AnyFunctionType>();
798
+ auto origSilFnType = originalFunction->getLoweredFunctionType ();
799
+ auto *loweredParamIndices = autodiff::getLoweredParameterIndices (
800
+ parameterIndices, origFnType);
801
+ // NOTE(TF-893): Extending capacity is necessary when `origSilFnType` has
802
+ // parameters corresponding to captured variables. These parameters do not
803
+ // appear in the type of `origFnType`.
804
+ // TODO: If posssible, change `autodiff::getLoweredParameterIndices` to
805
+ // take `CaptureInfo` into account.
806
+ if (origSilFnType->getNumParameters () > loweredParamIndices->getCapacity ())
807
+ loweredParamIndices = loweredParamIndices->extendingCapacity (
808
+ getASTContext (), origSilFnType->getNumParameters ());
809
+ SILAutoDiffIndices indices (/* source*/ 0 , loweredParamIndices);
810
+
811
+ // Self reordering thunk is necessary if wrt at least two parameters,
812
+ // including self.
813
+ auto shouldReorderSelf = [&]() {
814
+ if (!originalFunction->hasSelfParam ())
815
+ return false ;
816
+ auto selfParamIndex = origSilFnType->getNumParameters () - 1 ;
817
+ if (!indices.isWrtParameter (selfParamIndex))
818
+ return false ;
819
+ return indices.parameters ->getNumIndices () > 1 ;
820
+ };
821
+ bool reorderSelf = shouldReorderSelf ();
822
+
823
+ // Get or create differentiability witness.
824
+ CanGenericSignature derivativeGenSig;
825
+ if (jvp && vjp)
826
+ assert (jvp->getLoweredFunctionType ()->getGenericSignature () ==
827
+ vjp->getLoweredFunctionType ()->getGenericSignature () &&
828
+ " JVP and VJP generic signatures must match" );
829
+ if (jvp)
830
+ derivativeGenSig = jvp->getLoweredFunctionType ()->getGenericSignature ();
831
+ if (vjp)
832
+ derivativeGenSig = vjp->getLoweredFunctionType ()->getGenericSignature ();
833
+ auto *resultIndices = IndexSubset::get (getASTContext (), 1 , {0 });
834
+ AutoDiffConfig config{loweredParamIndices, resultIndices,
835
+ derivativeGenSig};
836
+ auto key = std::make_pair (originalFunction->getName (), config);
837
+ SILDifferentiabilityWitness *diffWitness = nullptr ;
838
+ if (auto *foundWitness = M.lookUpDifferentiabilityWitness (
839
+ key, /* deserializeLazily*/ false )) {
840
+ diffWitness = foundWitness;
841
+ } else {
842
+ // Create new SIL differentiability witness.
843
+ diffWitness = SILDifferentiabilityWitness::create (
844
+ M, originalFunction->getLinkage (), originalFunction,
845
+ loweredParamIndices, resultIndices, derivativeGenSig, /* jvp*/ nullptr ,
846
+ /* vjp*/ nullptr , /* isSerialized*/ true );
847
+ }
848
+
849
+ // Set derivative function in differentiability witness.
850
+ auto setDerivativeInDifferentiabilityWitness =
851
+ [&](AutoDiffDerivativeFunctionKind kind, SILFunction *derivative) {
852
+ auto expectedDerivativeType =
853
+ origSilFnType->getAutoDiffDerivativeFunctionType (
854
+ indices.parameters , indices.source , kind, Types,
855
+ LookUpConformanceInModule (M.getSwiftModule ()));
856
+ // Thunk derivative function.
857
+ SILFunction *derivativeThunk;
858
+ if (reorderSelf ||
859
+ derivative->getLoweredFunctionType () != expectedDerivativeType) {
860
+ derivativeThunk = getOrCreateAutoDiffDerivativeFunctionThunk (
861
+ originalFunction, indices, derivative, kind, reorderSelf);
862
+ } else {
863
+ auto *id = AutoDiffDerivativeFunctionIdentifier::get (
864
+ kind, parameterIndices, getASTContext ());
865
+ derivativeThunk = getOrCreateAutoDiffThunk (
866
+ SILDeclRef (originalAFD).asAutoDiffDerivativeFunction (id), derivative,
867
+ expectedDerivativeType);
868
+ }
869
+ // Check for existing same derivative.
870
+ // TODO(TF-898): Remove condition below and simplify assertion to
871
+ // `!diffWitness->getDerivative(kind)` after `@differentiating` attribute
872
+ // type-checking no longer generates implicit `@differentiable` attributes.
873
+ auto *existingDerivative = diffWitness->getDerivative (kind);
874
+ if (existingDerivative && existingDerivative == derivativeThunk)
875
+ return ;
876
+ assert (!existingDerivative &&
877
+ " SIL differentiability witness already has a different existing "
878
+ " derivative" );
879
+ diffWitness->setDerivative (kind, derivativeThunk);
880
+ };
881
+ if (jvp)
882
+ setDerivativeInDifferentiabilityWitness (AutoDiffDerivativeFunctionKind::JVP,
883
+ jvp);
884
+ if (vjp)
885
+ setDerivativeInDifferentiabilityWitness (AutoDiffDerivativeFunctionKind::VJP,
886
+ vjp);
887
+ }
888
+
836
889
void SILGenModule::
837
890
emitMarkFunctionEscapeForTopLevelCodeGlobals (SILLocation loc,
838
891
const CaptureInfo &captureInfo) {
0 commit comments