@@ -191,36 +191,6 @@ SILFunctionType::getWitnessMethodClass(SILModule &M) const {
191
191
return nullptr ;
192
192
}
193
193
194
- // Returns the canonical generic signature for an autodiff derivative function
195
- // given an existing derivative function generic signature. All
196
- // differentiability parameters are required to conform to `Differentiable`.
197
- static CanGenericSignature getAutoDiffDerivativeFunctionGenericSignature (
198
- CanGenericSignature derivativeFnGenSig,
199
- ArrayRef<SILParameterInfo> originalParameters,
200
- IndexSubset *parameterIndices, ModuleDecl *module ) {
201
- if (!derivativeFnGenSig)
202
- return nullptr ;
203
- auto &ctx = module ->getASTContext ();
204
- GenericSignatureBuilder builder (ctx);
205
- // Add derivative function generic signature.
206
- builder.addGenericSignature (derivativeFnGenSig);
207
- // All differentiability parameters are required to conform to
208
- // `Differentiable`.
209
- auto source =
210
- GenericSignatureBuilder::FloatingRequirementSource::forAbstract ();
211
- auto *differentiableProtocol =
212
- ctx.getProtocol (KnownProtocolKind::Differentiable);
213
- for (unsigned paramIdx : parameterIndices->getIndices ()) {
214
- auto paramType = originalParameters[paramIdx].getInterfaceType ();
215
- Requirement req (RequirementKind::Conformance, paramType,
216
- differentiableProtocol->getDeclaredType ());
217
- builder.addRequirement (req, source, module );
218
- }
219
- return std::move (builder)
220
- .computeGenericSignature (SourceLoc (), /* allowConcreteGenericParams*/ true )
221
- ->getCanonicalSignature ();
222
- }
223
-
224
194
CanSILFunctionType SILFunctionType::getAutoDiffDerivativeFunctionType (
225
195
IndexSubset *parameterIndices, unsigned resultIndex,
226
196
AutoDiffDerivativeFunctionKind kind, TypeConverter &TC,
@@ -243,8 +213,8 @@ CanSILFunctionType SILFunctionType::getAutoDiffDerivativeFunctionType(
243
213
// Get the canonical derivative function generic signature.
244
214
if (!derivativeFnGenSig)
245
215
derivativeFnGenSig = getSubstGenericSignature ();
246
- derivativeFnGenSig = getAutoDiffDerivativeFunctionGenericSignature (
247
- derivativeFnGenSig, getParameters (), parameterIndices, &TC. M );
216
+ derivativeFnGenSig = autodiff::getConstrainedDerivativeGenericSignature (
217
+ this , parameterIndices, derivativeFnGenSig). getCanonicalSignature ( );
248
218
249
219
// Given a type, returns its formal SIL parameter info.
250
220
auto getTangentParameterInfoForOriginalResult =
0 commit comments