Skip to content

Commit 9274e86

Browse files
authored
Merge pull request #29477 from dan-zheng/autodiff-upstream-diff-attr
2 parents 509ffb3 + b188b90 commit 9274e86

File tree

1 file changed

+15
-16
lines changed

1 file changed

+15
-16
lines changed

lib/Sema/TypeCheckAttr.cpp

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3564,13 +3564,13 @@ resolveDifferentiableAttrOriginalFunction(DifferentiableAttr *attr) {
35643564
}
35653565

35663566
/// Given a `@differentiable` attribute, attempts to resolve the derivative
3567-
/// generic environment. The derivative generic environment is returned as
3568-
/// `derivativeGenEnv`. On error, emits diagnostic, assigns `nullptr` to
3569-
/// `derivativeGenEnv`, and returns true.
3570-
bool resolveDifferentiableAttrDerivativeGenericEnvironment(
3567+
/// generic signature. The derivative generic signature is returned as
3568+
/// `derivativeGenSig`. On error, emits diagnostic, assigns `nullptr` to
3569+
/// `derivativeGenSig`, and returns true.
3570+
bool resolveDifferentiableAttrDerivativeGenericSignature(
35713571
DifferentiableAttr *attr, AbstractFunctionDecl *original,
3572-
GenericEnvironment *&derivativeGenEnv) {
3573-
derivativeGenEnv = nullptr;
3572+
GenericSignature &derivativeGenSig) {
3573+
derivativeGenSig = nullptr;
35743574

35753575
auto &ctx = original->getASTContext();
35763576
auto &diags = ctx.Diags;
@@ -3584,7 +3584,7 @@ bool resolveDifferentiableAttrDerivativeGenericEnvironment(
35843584
// - If the `@differentiable` attribute has a `where` clause, use it to
35853585
// compute the derivative generic signature.
35863586
// - Otherwise, use the original function's generic signature by default.
3587-
auto derivativeGenSig = original->getGenericSignature();
3587+
derivativeGenSig = original->getGenericSignature();
35883588

35893589
// Handle the `where` clause, if it exists.
35903590
// - Resolve attribute where clause requirements and store in the attribute
@@ -3663,10 +3663,9 @@ bool resolveDifferentiableAttrDerivativeGenericEnvironment(
36633663
return true;
36643664
}
36653665

3666-
// Compute generic signature and environment for derivative functions.
3666+
// Compute generic signature for derivative functions.
36673667
derivativeGenSig = std::move(builder).computeGenericSignature(
36683668
attr->getLocation(), /*allowConcreteGenericParams=*/true);
3669-
derivativeGenEnv = derivativeGenSig->getGenericEnvironment();
36703669
}
36713670

36723671
// Set the resolved derivative generic signature in the attribute.
@@ -3893,14 +3892,14 @@ llvm::Expected<IndexSubset *> DifferentiableAttributeTypeCheckRequest::evaluate(
38933892
}
38943893
}
38953894

3896-
// Resolve the derivative generic environment.
3897-
GenericEnvironment *derivativeGenEnv = nullptr;
3898-
if (resolveDifferentiableAttrDerivativeGenericEnvironment(attr, original,
3899-
derivativeGenEnv))
3895+
// Resolve the derivative generic signature.
3896+
GenericSignature derivativeGenSig = nullptr;
3897+
if (resolveDifferentiableAttrDerivativeGenericSignature(attr, original,
3898+
derivativeGenSig))
39003899
return nullptr;
3901-
GenericSignature derivativeGenSig;
3902-
if (derivativeGenEnv)
3903-
derivativeGenSig = derivativeGenEnv->getGenericSignature();
3900+
GenericEnvironment *derivativeGenEnv = nullptr;
3901+
if (derivativeGenSig)
3902+
derivativeGenEnv = derivativeGenSig->getGenericEnvironment();
39043903

39053904
// Compute the derivative function type.
39063905
auto derivativeFnTy = originalFnTy;

0 commit comments

Comments
 (0)