Skip to content

Commit 865c4ea

Browse files
committed
Gate @derivative attribute by differentiable programming flag.
Gate `@derivative` attribute by `-enable-experimental-differentiable-programming` flag.
1 parent 60b11ff commit 865c4ea

File tree

2 files changed

+18
-1
lines changed

2 files changed

+18
-1
lines changed

lib/Sema/TypeCheckAttr.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3354,7 +3354,15 @@ getAutoDiffOriginalFunctionType(AnyFunctionType *derivativeFnTy) {
33543354
}
33553355

33563356
void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
3357-
FuncDecl *derivative = cast<FuncDecl>(D);
3357+
// `@derivative` attribute requires experimental differentiable programming
3358+
// to be enabled.
3359+
auto &ctx = D->getASTContext();
3360+
if (!ctx.LangOpts.EnableExperimentalDifferentiableProgramming) {
3361+
diagnoseAndRemoveAttr(
3362+
attr, diag::experimental_differentiable_programming_disabled);
3363+
return;
3364+
}
3365+
auto *derivative = cast<FuncDecl>(D);
33583366
auto lookupConformance =
33593367
LookUpConformanceInModule(D->getDeclContext()->getParentModule());
33603368
auto originalName = attr->getOriginalFunctionName();

test/AutoDiff/Sema/differentiable_features_disabled.swift

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,12 @@
22

33
// expected-error @+1 {{differentiable programming is an experimental feature that is currently disabled}}
44
let _: @differentiable (Float) -> Float
5+
6+
func id(_ x: Float) -> Float {
7+
return x
8+
}
9+
// expected-error @+1 {{differentiable programming is an experimental feature that is currently disabled}}
10+
@derivative(of: id)
11+
func jvpId(x: Float) -> (value: Float, differential: (Float) -> (Float)) {
12+
return (x, { $0 })
13+
}

0 commit comments

Comments
 (0)