@@ -772,22 +772,22 @@ void SILGenModule::postEmitFunction(SILDeclRef constant,
772
772
jvp = getFunction (SILDeclRef (jvpDecl), NotForDefinition);
773
773
if (auto *vjpDecl = diffAttr->getVJPFunction ())
774
774
vjp = getFunction (SILDeclRef (vjpDecl), NotForDefinition);
775
- emitDifferentiabilityWitness (AFD, F, diffAttr->getParameterIndices (), jvp,
776
- vjp,
777
- diffAttr->getDerivativeGenericSignature ());
775
+ auto *resultIndices = IndexSubset::get (getASTContext (), 1 , {0 });
776
+ AutoDiffConfig config{diffAttr->getParameterIndices (), resultIndices,
777
+ diffAttr->getDerivativeGenericSignature ()};
778
+ emitDifferentiabilityWitness (AFD, F, config, jvp, vjp);
778
779
}
779
780
}
780
781
F->verify ();
781
782
}
782
783
783
784
void SILGenModule::emitDifferentiabilityWitness (
784
785
AbstractFunctionDecl *originalAFD, SILFunction *originalFunction,
785
- IndexSubset *parameterIndices, SILFunction *jvp, SILFunction *vjp,
786
- GenericSignature *derivativeGenSig) {
786
+ const AutoDiffConfig &config, SILFunction *jvp, SILFunction *vjp) {
787
787
auto *origFnType = originalAFD->getInterfaceType ()->castTo <AnyFunctionType>();
788
788
auto origSilFnType = originalFunction->getLoweredFunctionType ();
789
789
auto *loweredParamIndices = autodiff::getLoweredParameterIndices (
790
- parameterIndices, origFnType);
790
+ config. parameterIndices , origFnType);
791
791
// NOTE(TF-893): Extending capacity is necessary when `origSilFnType` has
792
792
// parameters corresponding to captured variables. These parameters do not
793
793
// appear in the type of `origFnType`.
@@ -813,7 +813,7 @@ void SILGenModule::emitDifferentiabilityWitness(
813
813
814
814
// Get or create differentiability witness.
815
815
CanGenericSignature derivativeCanGenSig;
816
- if (derivativeGenSig)
816
+ if (auto * derivativeGenSig = config. derivativeGenericSignature )
817
817
derivativeCanGenSig = derivativeGenSig->getCanonicalSignature ();
818
818
// TODO(TF-835): Use simpler derivative generic signature logic below when
819
819
// type-checking no longer generates implicit `@differentiable` attributes.
@@ -830,13 +830,12 @@ void SILGenModule::emitDifferentiabilityWitness(
830
830
derivativeCanGenSig = vjpCanGenSig;
831
831
assert (derivativeCanGenSig == vjpCanGenSig);
832
832
}
833
- auto *resultIndices = IndexSubset::get (getASTContext (), 1 , {0 });
834
833
// Create new SIL differentiability witness.
835
834
// Witness JVP and VJP are set below.
836
835
auto *diffWitness = SILDifferentiabilityWitness::create (
837
836
M, originalFunction->getLinkage (), originalFunction,
838
- loweredParamIndices, resultIndices, derivativeGenSig, /* jvp */ nullptr ,
839
- /* vjp*/ nullptr , /* isSerialized*/ true );
837
+ loweredParamIndices, config. resultIndices , derivativeCanGenSig ,
838
+ /* jvp */ nullptr , /* vjp*/ nullptr , /* isSerialized*/ true );
840
839
841
840
// Set derivative function in differentiability witness.
842
841
auto setDerivativeInDifferentiabilityWitness =
@@ -852,8 +851,10 @@ void SILGenModule::emitDifferentiabilityWitness(
852
851
derivativeThunk = getOrCreateAutoDiffDerivativeFunctionThunk (
853
852
originalFunction, indices, derivative, kind, reorderSelf);
854
853
} else {
854
+ // Note: `AutoDiffDerivativeFunctionIdentifier` must be constructed with
855
+ // the AST-level parameter indices, not the SIL-level ones.
855
856
auto *id = AutoDiffDerivativeFunctionIdentifier::get (
856
- kind, parameterIndices, getASTContext ());
857
+ kind, config. parameterIndices , getASTContext ());
857
858
derivativeThunk = getOrCreateAutoDiffThunk (
858
859
SILDeclRef (originalAFD).asAutoDiffDerivativeFunction (id), derivative,
859
860
expectedDerivativeType);
0 commit comments