Skip to content

Commit 44d7ae6

Browse files
committed
Add utility for checking whether differentiable programming is enabled.
1 parent 469ecb6 commit 44d7ae6

File tree

1 file changed

+29
-23
lines changed

1 file changed

+29
-23
lines changed

lib/Sema/TypeCheckAttr.cpp

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3808,6 +3808,31 @@ bool resolveDifferentiableAttrDerivativeFunctions(
38083808
return false;
38093809
}
38103810

3811+
/// Checks whether differentiable programming is enabled for the given
3812+
/// differentiation-related attribute. Returns true on error.
3813+
bool checkIfDifferentiableProgrammingEnabled(
3814+
ASTContext &ctx, DeclAttribute *attr) {
3815+
auto &diags = ctx.Diags;
3816+
// The experimental differentiable programming flag must be enabled.
3817+
if (!ctx.LangOpts.EnableExperimentalDifferentiableProgramming) {
3818+
diags
3819+
.diagnose(attr->getLocation(),
3820+
diag::experimental_differentiable_programming_disabled)
3821+
.highlight(attr->getRangeWithAt());
3822+
return true;
3823+
}
3824+
// The `Differentiable` protocol must be available.
3825+
// If unavailable, the `_Differentiation` module should be imported.
3826+
if (!ctx.getProtocol(KnownProtocolKind::Differentiable)) {
3827+
diags
3828+
.diagnose(attr->getLocation(), diag::attr_used_without_required_module,
3829+
attr, ctx.Id_Differentiation)
3830+
.highlight(attr->getRangeWithAt());
3831+
return true;
3832+
}
3833+
return false;
3834+
}
3835+
38113836
llvm::Expected<IndexSubset *> DifferentiableAttributeTypeCheckRequest::evaluate(
38123837
Evaluator &evaluator, DifferentiableAttr *attr) const {
38133838
// Skip type-checking for implicit `@differentiable` attributes. We currently
@@ -3824,21 +3849,8 @@ llvm::Expected<IndexSubset *> DifferentiableAttributeTypeCheckRequest::evaluate(
38243849
auto &diags = ctx.Diags;
38253850
// `@differentiable` attribute requires experimental differentiable
38263851
// programming to be enabled.
3827-
if (!ctx.LangOpts.EnableExperimentalDifferentiableProgramming) {
3828-
diags
3829-
.diagnose(attr->getLocation(),
3830-
diag::experimental_differentiable_programming_disabled)
3831-
.highlight(attr->getRangeWithAt());
3852+
if (checkIfDifferentiableProgrammingEnabled(ctx, attr))
38323853
return nullptr;
3833-
}
3834-
// The `Differentiable` protocol must be available.
3835-
if (!ctx.getProtocol(KnownProtocolKind::Differentiable)) {
3836-
diags
3837-
.diagnose(attr->getLocation(), diag::attr_used_without_required_module,
3838-
attr, ctx.Id_Differentiation)
3839-
.highlight(attr->getRangeWithAt());
3840-
return nullptr;
3841-
}
38423854

38433855
// Derivative registration is disabled for `@differentiable(linear)`
38443856
// attributes. Instead, use `@transpose` attribute to register transpose
@@ -3990,7 +4002,7 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
39904002
(void)attr->getParameterIndices();
39914003
}
39924004

3993-
/// Typechecks the given derivative attribute `attr` on decl `D`.
4005+
/// Type-checks the given `@derivative` attribute `attr` on declaration `D`.
39944006
///
39954007
/// Effects are:
39964008
/// - Sets the original function and parameter indices on `attr`.
@@ -4000,19 +4012,13 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
40004012
/// \returns true on error, false on success.
40014013
static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
40024014
DerivativeAttr *attr) {
4003-
// Note: Implementation must be idempotent because it can get called multiple
4015+
// Note: Implementation must be idempotent because it may be called multiple
40044016
// times for the same attribute.
4005-
40064017
auto &diags = Ctx.Diags;
4007-
40084018
// `@derivative` attribute requires experimental differentiable programming
40094019
// to be enabled.
4010-
auto &ctx = D->getASTContext();
4011-
if (!ctx.LangOpts.EnableExperimentalDifferentiableProgramming) {
4012-
diags.diagnose(attr->getLocation(),
4013-
diag::experimental_differentiable_programming_disabled);
4020+
if (checkIfDifferentiableProgrammingEnabled(Ctx, attr))
40144021
return true;
4015-
}
40164022
auto *derivative = cast<FuncDecl>(D);
40174023
auto lookupConformance =
40184024
LookUpConformanceInModule(D->getDeclContext()->getParentModule());

0 commit comments

Comments
 (0)