Skip to content

Commit b601662

Browse files
committed
Clean up @differentiable attribute type-checking.
1 parent 03593f4 commit b601662

File tree

1 file changed

+12
-13
lines changed

1 file changed

+12
-13
lines changed

lib/Sema/TypeCheckAttr.cpp

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3663,14 +3663,13 @@ resolveDifferentiableAttrOriginalFunction(DifferentiableAttr *attr) {
36633663
}
36643664

36653665
/// Given a `@differentiable` attribute, attempts to resolve the derivative
3666-
/// generic environment. The derivative generic environment is returned as
3667-
/// `derivativeGenEnv`. On error, emits diagnostic, assigns `nullptr` to
3668-
/// `derivativeGenEnv`, and returns true.
3669-
bool resolveDifferentiableAttrDerivativeGenericEnvironment(
3666+
/// generic signature. The derivative generic signature is returned as
3667+
/// `derivativeGenSig`. On error, emits diagnostic, assigns `nullptr` to
3668+
/// `derivativeGenSig`, and returns true.
3669+
bool resolveDifferentiableAttrDerivativeGenericSignature(
36703670
DifferentiableAttr *attr, AbstractFunctionDecl *original,
3671-
GenericSignature &derivativeGenSig, GenericEnvironment *&derivativeGenEnv) {
3671+
GenericSignature &derivativeGenSig) {
36723672
derivativeGenSig = nullptr;
3673-
derivativeGenEnv = nullptr;
36743673

36753674
auto &ctx = original->getASTContext();
36763675
auto &diags = ctx.Diags;
@@ -3763,10 +3762,9 @@ bool resolveDifferentiableAttrDerivativeGenericEnvironment(
37633762
return true;
37643763
}
37653764

3766-
// Compute generic signature and environment for derivative functions.
3765+
// Compute generic signature for derivative functions.
37673766
derivativeGenSig = std::move(builder).computeGenericSignature(
37683767
attr->getLocation(), /*allowConcreteGenericParams=*/true);
3769-
derivativeGenEnv = derivativeGenSig->getGenericEnvironment();
37703768
}
37713769

37723770
// Set the resolved derivative generic signature in the attribute.
@@ -3993,13 +3991,14 @@ llvm::Expected<IndexSubset *> DifferentiableAttributeTypeCheckRequest::evaluate(
39933991
}
39943992
}
39953993

3996-
// Resolve the derivative generic environment.
3994+
// Resolve the derivative generic signature.
39973995
GenericSignature derivativeGenSig;
3998-
GenericEnvironment *derivativeGenEnv = nullptr;
3999-
if (resolveDifferentiableAttrDerivativeGenericEnvironment(attr, original,
4000-
derivativeGenSig,
4001-
derivativeGenEnv))
3996+
if (resolveDifferentiableAttrDerivativeGenericSignature(attr, original,
3997+
derivativeGenSig))
40023998
return nullptr;
3999+
GenericEnvironment *derivativeGenEnv = nullptr;
4000+
if (derivativeGenSig)
4001+
derivativeGenEnv = derivativeGenSig->getGenericEnvironment();
40034002

40044003
// Compute the derivative function type.
40054004
auto derivativeFnTy = originalFnTy;

0 commit comments

Comments
 (0)