@@ -3808,6 +3808,31 @@ bool resolveDifferentiableAttrDerivativeFunctions(
3808
3808
return false ;
3809
3809
}
3810
3810
3811
+ // / Checks whether differentiable programming is enabled for the given
3812
+ // / differentiation-related attribute. Returns true on error.
3813
+ bool checkIfDifferentiableProgrammingEnabled (
3814
+ ASTContext &ctx, DeclAttribute *attr) {
3815
+ auto &diags = ctx.Diags ;
3816
+ // The experimental differentiable programming flag must be enabled.
3817
+ if (!ctx.LangOpts .EnableExperimentalDifferentiableProgramming ) {
3818
+ diags
3819
+ .diagnose (attr->getLocation (),
3820
+ diag::experimental_differentiable_programming_disabled)
3821
+ .highlight (attr->getRangeWithAt ());
3822
+ return true ;
3823
+ }
3824
+ // The `Differentiable` protocol must be available.
3825
+ // If unavailable, the `_Differentiation` module should be imported.
3826
+ if (!ctx.getProtocol (KnownProtocolKind::Differentiable)) {
3827
+ diags
3828
+ .diagnose (attr->getLocation (), diag::attr_used_without_required_module,
3829
+ attr, ctx.Id_Differentiation )
3830
+ .highlight (attr->getRangeWithAt ());
3831
+ return true ;
3832
+ }
3833
+ return false ;
3834
+ }
3835
+
3811
3836
llvm::Expected<IndexSubset *> DifferentiableAttributeTypeCheckRequest::evaluate (
3812
3837
Evaluator &evaluator, DifferentiableAttr *attr) const {
3813
3838
// Skip type-checking for implicit `@differentiable` attributes. We currently
@@ -3824,21 +3849,8 @@ llvm::Expected<IndexSubset *> DifferentiableAttributeTypeCheckRequest::evaluate(
3824
3849
auto &diags = ctx.Diags ;
3825
3850
// `@differentiable` attribute requires experimental differentiable
3826
3851
// programming to be enabled.
3827
- if (!ctx.LangOpts .EnableExperimentalDifferentiableProgramming ) {
3828
- diags
3829
- .diagnose (attr->getLocation (),
3830
- diag::experimental_differentiable_programming_disabled)
3831
- .highlight (attr->getRangeWithAt ());
3852
+ if (checkIfDifferentiableProgrammingEnabled (ctx, attr))
3832
3853
return nullptr ;
3833
- }
3834
- // The `Differentiable` protocol must be available.
3835
- if (!ctx.getProtocol (KnownProtocolKind::Differentiable)) {
3836
- diags
3837
- .diagnose (attr->getLocation (), diag::attr_used_without_required_module,
3838
- attr, ctx.Id_Differentiation )
3839
- .highlight (attr->getRangeWithAt ());
3840
- return nullptr ;
3841
- }
3842
3854
3843
3855
// Derivative registration is disabled for `@differentiable(linear)`
3844
3856
// attributes. Instead, use `@transpose` attribute to register transpose
@@ -3990,7 +4002,7 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
3990
4002
(void )attr->getParameterIndices ();
3991
4003
}
3992
4004
3993
- // / Typechecks the given derivative attribute `attr` on decl `D`.
4005
+ // / Type-checks the given `@ derivative` attribute `attr` on declaration `D`.
3994
4006
// /
3995
4007
// / Effects are:
3996
4008
// / - Sets the original function and parameter indices on `attr`.
@@ -4000,19 +4012,13 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
4000
4012
// / \returns true on error, false on success.
4001
4013
static bool typeCheckDerivativeAttr (ASTContext &Ctx, Decl *D,
4002
4014
DerivativeAttr *attr) {
4003
- // Note: Implementation must be idempotent because it can get called multiple
4015
+ // Note: Implementation must be idempotent because it may be called multiple
4004
4016
// times for the same attribute.
4005
-
4006
4017
auto &diags = Ctx.Diags ;
4007
-
4008
4018
// `@derivative` attribute requires experimental differentiable programming
4009
4019
// to be enabled.
4010
- auto &ctx = D->getASTContext ();
4011
- if (!ctx.LangOpts .EnableExperimentalDifferentiableProgramming ) {
4012
- diags.diagnose (attr->getLocation (),
4013
- diag::experimental_differentiable_programming_disabled);
4020
+ if (checkIfDifferentiableProgrammingEnabled (Ctx, attr))
4014
4021
return true ;
4015
- }
4016
4022
auto *derivative = cast<FuncDecl>(D);
4017
4023
auto lookupConformance =
4018
4024
LookUpConformanceInModule (D->getDeclContext ()->getParentModule ());
0 commit comments