Skip to content

Commit 5e030b5

Browse files
authored
Merge pull request #29918 from dan-zheng/derivative-attr-diagnostics
2 parents e5e9fce + 44d7ae6 commit 5e030b5

File tree

3 files changed

+67
-45
lines changed

3 files changed

+67
-45
lines changed

include/swift/AST/DiagnosticsSema.def

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2993,18 +2993,18 @@ NOTE(protocol_witness_missing_differentiable_attr,none,
29932993
// @derivative
29942994
ERROR(derivative_attr_expected_result_tuple,none,
29952995
"'@derivative(of:)' attribute requires function to return a two-element "
2996-
"tuple of type '(value: T..., pullback: (U.TangentVector) -> T.TangentVector...)' "
2997-
"or '(value: T..., differential: (T.TangentVector...) -> U.TangentVector)'", ())
2996+
"tuple; first element must have label 'value:' and second element must "
2997+
"have label 'pullback:' or 'differential:'", ())
29982998
ERROR(derivative_attr_invalid_result_tuple_value_label,none,
29992999
"'@derivative(of:)' attribute requires function to return a two-element "
3000-
"tuple (first element must have label 'value:')", ())
3000+
"tuple; first element must have label 'value:'", ())
30013001
ERROR(derivative_attr_invalid_result_tuple_func_label,none,
30023002
"'@derivative(of:)' attribute requires function to return a two-element "
3003-
"tuple (second element must have label 'pullback:' or 'differential:')",
3003+
"tuple; second element must have label 'pullback:' or 'differential:'",
30043004
())
30053005
ERROR(derivative_attr_result_value_not_differentiable,none,
30063006
"'@derivative(of:)' attribute requires function to return a two-element "
3007-
"tuple (first element type %0 must conform to 'Differentiable')", (Type))
3007+
"tuple; first element type %0 must conform to 'Differentiable'", (Type))
30083008
ERROR(derivative_attr_result_func_type_mismatch,none,
30093009
"function result's %0 type does not match %1", (Identifier, DeclName))
30103010
NOTE(derivative_attr_result_func_type_mismatch_note,none,

lib/Sema/TypeCheckAttr.cpp

Lines changed: 42 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -3822,6 +3822,31 @@ bool resolveDifferentiableAttrDerivativeFunctions(
38223822
return false;
38233823
}
38243824

3825+
/// Checks whether differentiable programming is enabled for the given
3826+
/// differentiation-related attribute. Returns true on error.
3827+
bool checkIfDifferentiableProgrammingEnabled(
3828+
ASTContext &ctx, DeclAttribute *attr) {
3829+
auto &diags = ctx.Diags;
3830+
// The experimental differentiable programming flag must be enabled.
3831+
if (!ctx.LangOpts.EnableExperimentalDifferentiableProgramming) {
3832+
diags
3833+
.diagnose(attr->getLocation(),
3834+
diag::experimental_differentiable_programming_disabled)
3835+
.highlight(attr->getRangeWithAt());
3836+
return true;
3837+
}
3838+
// The `Differentiable` protocol must be available.
3839+
// If unavailable, the `_Differentiation` module should be imported.
3840+
if (!ctx.getProtocol(KnownProtocolKind::Differentiable)) {
3841+
diags
3842+
.diagnose(attr->getLocation(), diag::attr_used_without_required_module,
3843+
attr, ctx.Id_Differentiation)
3844+
.highlight(attr->getRangeWithAt());
3845+
return true;
3846+
}
3847+
return false;
3848+
}
3849+
38253850
llvm::Expected<IndexSubset *> DifferentiableAttributeTypeCheckRequest::evaluate(
38263851
Evaluator &evaluator, DifferentiableAttr *attr) const {
38273852
// Skip type-checking for implicit `@differentiable` attributes. We currently
@@ -3838,21 +3863,8 @@ llvm::Expected<IndexSubset *> DifferentiableAttributeTypeCheckRequest::evaluate(
38383863
auto &diags = ctx.Diags;
38393864
// `@differentiable` attribute requires experimental differentiable
38403865
// programming to be enabled.
3841-
if (!ctx.LangOpts.EnableExperimentalDifferentiableProgramming) {
3842-
diags
3843-
.diagnose(attr->getLocation(),
3844-
diag::experimental_differentiable_programming_disabled)
3845-
.highlight(attr->getRangeWithAt());
3846-
return nullptr;
3847-
}
3848-
// The `Differentiable` protocol must be available.
3849-
if (!ctx.getProtocol(KnownProtocolKind::Differentiable)) {
3850-
diags
3851-
.diagnose(attr->getLocation(), diag::attr_used_without_required_module,
3852-
attr, ctx.Id_Differentiation)
3853-
.highlight(attr->getRangeWithAt());
3866+
if (checkIfDifferentiableProgrammingEnabled(ctx, attr))
38543867
return nullptr;
3855-
}
38563868

38573869
// Derivative registration is disabled for `@differentiable(linear)`
38583870
// attributes. Instead, use `@transpose` attribute to register transpose
@@ -4011,7 +4023,7 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
40114023
(void)attr->getParameterIndices();
40124024
}
40134025

4014-
/// Typechecks the given derivative attribute `attr` on decl `D`.
4026+
/// Type-checks the given `@derivative` attribute `attr` on declaration `D`.
40154027
///
40164028
/// Effects are:
40174029
/// - Sets the original function and parameter indices on `attr`.
@@ -4021,19 +4033,13 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
40214033
/// \returns true on error, false on success.
40224034
static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
40234035
DerivativeAttr *attr) {
4024-
// Note: Implementation must be idempotent because it can get called multiple
4036+
// Note: Implementation must be idempotent because it may be called multiple
40254037
// times for the same attribute.
4026-
40274038
auto &diags = Ctx.Diags;
4028-
40294039
// `@derivative` attribute requires experimental differentiable programming
40304040
// to be enabled.
4031-
auto &ctx = D->getASTContext();
4032-
if (!ctx.LangOpts.EnableExperimentalDifferentiableProgramming) {
4033-
diags.diagnose(attr->getLocation(),
4034-
diag::experimental_differentiable_programming_disabled);
4041+
if (checkIfDifferentiableProgrammingEnabled(Ctx, attr))
40354042
return true;
4036-
}
40374043
auto *derivative = cast<FuncDecl>(D);
40384044
auto lookupConformance =
40394045
LookUpConformanceInModule(D->getDeclContext()->getParentModule());
@@ -4075,19 +4081,6 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
40754081
return true;
40764082
}
40774083
attr->setDerivativeKind(kind);
4078-
// `value: R` result tuple element must conform to `Differentiable`.
4079-
auto diffableProto = Ctx.getProtocol(KnownProtocolKind::Differentiable);
4080-
auto valueResultType = valueResultElt.getType();
4081-
if (valueResultType->hasTypeParameter())
4082-
valueResultType = derivative->mapTypeIntoContext(valueResultType);
4083-
auto valueResultConf = TypeChecker::conformsToProtocol(
4084-
valueResultType, diffableProto, derivative->getDeclContext(), None);
4085-
if (!valueResultConf) {
4086-
diags.diagnose(attr->getLocation(),
4087-
diag::derivative_attr_result_value_not_differentiable,
4088-
valueResultElt.getType());
4089-
return true;
4090-
}
40914084

40924085
// Compute expected original function type and look up original function.
40934086
auto *originalFnType =
@@ -4274,6 +4267,7 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
42744267
diffParamTypes);
42754268

42764269
// Get the differentiability parameters' `TangentVector` associated types.
4270+
auto *diffableProto = Ctx.getProtocol(KnownProtocolKind::Differentiable);
42774271
auto diffParamTanTypes =
42784272
map<SmallVector<TupleTypeElt, 4>>(diffParamTypes, [&](Type paramType) {
42794273
if (paramType->hasTypeParameter())
@@ -4287,7 +4281,19 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
42874281
return TupleTypeElt(paramAssocType);
42884282
});
42894283

4284+
// `value: R` result tuple element must conform to `Differentiable`.
42904285
// Get the `TangentVector` associated type of the `value:` result type.
4286+
auto valueResultType = valueResultElt.getType();
4287+
if (valueResultType->hasTypeParameter())
4288+
valueResultType = derivative->mapTypeIntoContext(valueResultType);
4289+
auto valueResultConf = TypeChecker::conformsToProtocol(
4290+
valueResultType, diffableProto, derivative->getDeclContext(), None);
4291+
if (!valueResultConf) {
4292+
diags.diagnose(attr->getLocation(),
4293+
diag::derivative_attr_result_value_not_differentiable,
4294+
valueResultElt.getType());
4295+
return true;
4296+
}
42914297
auto resultTanType = valueResultConf.getTypeWitnessByName(
42924298
valueResultType, Ctx.Id_TangentVector);
42934299

test/AutoDiff/Sema/derivative_attr_type_checking.swift

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,29 +61,45 @@ func vjpSubtractWrt1(x: Float, y: Float) -> (value: Float, pullback: (Float) ->
6161
return (x - y, { $0 })
6262
}
6363

64+
// Test invalid original function.
65+
66+
// expected-error @+1 {{use of unresolved identifier 'nonexistentFunction'}}
67+
@derivative(of: nonexistentFunction)
68+
func vjpOriginalFunctionNotFound(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
69+
fatalError()
70+
}
71+
72+
// Test `@derivative` attribute where `value:` result does not conform to `Differentiable`.
73+
// Invalid original function should be diagnosed first.
74+
// expected-error @+1 {{use of unresolved identifier 'nonexistentFunction'}}
75+
@derivative(of: nonexistentFunction)
76+
func vjpOriginalFunctionNotFound2(_ x: Float) -> (value: Int, pullback: (Float) -> Float) {
77+
fatalError()
78+
}
79+
6480
// Test incorrect `@derivative` declaration type.
6581

6682
// expected-note @+1 {{'incorrectDerivativeType' defined here}}
6783
func incorrectDerivativeType(_ x: Float) -> Float {
6884
return x
6985
}
7086

71-
// expected-error @+1 {{'@derivative(of:)' attribute requires function to return a two-element tuple of type '(value: T..., pullback: (U.TangentVector) -> T.TangentVector...)' or '(value: T..., differential: (T.TangentVector...) -> U.TangentVector)'}}
87+
// expected-error @+1 {{'@derivative(of:)' attribute requires function to return a two-element tuple; first element must have label 'value:' and second element must have label 'pullback:' or 'differential:'}}
7288
@derivative(of: incorrectDerivativeType)
7389
func jvpResultIncorrect(x: Float) -> Float {
7490
return x
7591
}
76-
// expected-error @+1 {{'@derivative(of:)' attribute requires function to return a two-element tuple (first element must have label 'value:'}}
92+
// expected-error @+1 {{'@derivative(of:)' attribute requires function to return a two-element tuple; first element must have label 'value:'}}
7793
@derivative(of: incorrectDerivativeType)
7894
func vjpResultIncorrectFirstLabel(x: Float) -> (Float, (Float) -> Float) {
7995
return (x, { $0 })
8096
}
81-
// expected-error @+1 {{'@derivative(of:)' attribute requires function to return a two-element tuple (second element must have label 'pullback:' or 'differential:')}}
97+
// expected-error @+1 {{'@derivative(of:)' attribute requires function to return a two-element tuple; second element must have label 'pullback:' or 'differential:'}}
8298
@derivative(of: incorrectDerivativeType)
8399
func vjpResultIncorrectSecondLabel(x: Float) -> (value: Float, (Float) -> Float) {
84100
return (x, { $0 })
85101
}
86-
// expected-error @+1 {{'@derivative(of:)' attribute requires function to return a two-element tuple (first element type 'Int' must conform to 'Differentiable')}}
102+
// expected-error @+1 {{could not find function 'incorrectDerivativeType' with expected type '(Int) -> Int'}}
87103
@derivative(of: incorrectDerivativeType)
88104
func vjpResultNotDifferentiable(x: Int) -> (
89105
value: Int, pullback: (Int) -> Int

0 commit comments

Comments
 (0)