@@ -3564,13 +3564,13 @@ resolveDifferentiableAttrOriginalFunction(DifferentiableAttr *attr) {
3564
3564
}
3565
3565
3566
3566
// / Given a `@differentiable` attribute, attempts to resolve the derivative
3567
- // / generic environment . The derivative generic environment is returned as
3568
- // / `derivativeGenEnv `. On error, emits diagnostic, assigns `nullptr` to
3569
- // / `derivativeGenEnv `, and returns true.
3570
- bool resolveDifferentiableAttrDerivativeGenericEnvironment (
3567
+ // / generic signature . The derivative generic signature is returned as
3568
+ // / `derivativeGenSig `. On error, emits diagnostic, assigns `nullptr` to
3569
+ // / `derivativeGenSig `, and returns true.
3570
+ bool resolveDifferentiableAttrDerivativeGenericSignature (
3571
3571
DifferentiableAttr *attr, AbstractFunctionDecl *original,
3572
- GenericEnvironment *&derivativeGenEnv ) {
3573
- derivativeGenEnv = nullptr ;
3572
+ GenericSignature &derivativeGenSig ) {
3573
+ derivativeGenSig = nullptr ;
3574
3574
3575
3575
auto &ctx = original->getASTContext ();
3576
3576
auto &diags = ctx.Diags ;
@@ -3584,7 +3584,7 @@ bool resolveDifferentiableAttrDerivativeGenericEnvironment(
3584
3584
// - If the `@differentiable` attribute has a `where` clause, use it to
3585
3585
// compute the derivative generic signature.
3586
3586
// - Otherwise, use the original function's generic signature by default.
3587
- auto derivativeGenSig = original->getGenericSignature ();
3587
+ derivativeGenSig = original->getGenericSignature ();
3588
3588
3589
3589
// Handle the `where` clause, if it exists.
3590
3590
// - Resolve attribute where clause requirements and store in the attribute
@@ -3663,10 +3663,9 @@ bool resolveDifferentiableAttrDerivativeGenericEnvironment(
3663
3663
return true ;
3664
3664
}
3665
3665
3666
- // Compute generic signature and environment for derivative functions.
3666
+ // Compute generic signature for derivative functions.
3667
3667
derivativeGenSig = std::move (builder).computeGenericSignature (
3668
3668
attr->getLocation (), /* allowConcreteGenericParams=*/ true );
3669
- derivativeGenEnv = derivativeGenSig->getGenericEnvironment ();
3670
3669
}
3671
3670
3672
3671
// Set the resolved derivative generic signature in the attribute.
@@ -3893,14 +3892,14 @@ llvm::Expected<IndexSubset *> DifferentiableAttributeTypeCheckRequest::evaluate(
3893
3892
}
3894
3893
}
3895
3894
3896
- // Resolve the derivative generic environment .
3897
- GenericEnvironment *derivativeGenEnv = nullptr ;
3898
- if (resolveDifferentiableAttrDerivativeGenericEnvironment (attr, original,
3899
- derivativeGenEnv ))
3895
+ // Resolve the derivative generic signature .
3896
+ GenericSignature derivativeGenSig = nullptr ;
3897
+ if (resolveDifferentiableAttrDerivativeGenericSignature (attr, original,
3898
+ derivativeGenSig ))
3900
3899
return nullptr ;
3901
- GenericSignature derivativeGenSig ;
3902
- if (derivativeGenEnv )
3903
- derivativeGenSig = derivativeGenEnv-> getGenericSignature ();
3900
+ GenericEnvironment *derivativeGenEnv = nullptr ;
3901
+ if (derivativeGenSig )
3902
+ derivativeGenEnv = derivativeGenSig-> getGenericEnvironment ();
3904
3903
3905
3904
// Compute the derivative function type.
3906
3905
auto derivativeFnTy = originalFnTy;
0 commit comments