Skip to content

Commit 469ecb6

Browse files
committed
Improve @derivative type-checking diagnostics order.
Attempt to look up original function before checking whether the `value:` result conforms to `Differentiable`. This improves diagnostics: "original function not found" should be diagnosed as early as possible.
1 parent 2c70f0a commit 469ecb6

File tree

2 files changed

+30
-14
lines changed

2 files changed

+30
-14
lines changed

lib/Sema/TypeCheckAttr.cpp

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4054,19 +4054,6 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
40544054
return true;
40554055
}
40564056
attr->setDerivativeKind(kind);
4057-
// `value: R` result tuple element must conform to `Differentiable`.
4058-
auto diffableProto = Ctx.getProtocol(KnownProtocolKind::Differentiable);
4059-
auto valueResultType = valueResultElt.getType();
4060-
if (valueResultType->hasTypeParameter())
4061-
valueResultType = derivative->mapTypeIntoContext(valueResultType);
4062-
auto valueResultConf = TypeChecker::conformsToProtocol(
4063-
valueResultType, diffableProto, derivative->getDeclContext(), None);
4064-
if (!valueResultConf) {
4065-
diags.diagnose(attr->getLocation(),
4066-
diag::derivative_attr_result_value_not_differentiable,
4067-
valueResultElt.getType());
4068-
return true;
4069-
}
40704057

40714058
// Compute expected original function type and look up original function.
40724059
auto *originalFnType =
@@ -4221,6 +4208,7 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
42214208
diffParamTypes);
42224209

42234210
// Get the differentiability parameters' `TangentVector` associated types.
4211+
auto *diffableProto = Ctx.getProtocol(KnownProtocolKind::Differentiable);
42244212
auto diffParamTanTypes =
42254213
map<SmallVector<TupleTypeElt, 4>>(diffParamTypes, [&](Type paramType) {
42264214
if (paramType->hasTypeParameter())
@@ -4234,7 +4222,19 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
42344222
return TupleTypeElt(paramAssocType);
42354223
});
42364224

4225+
// `value: R` result tuple element must conform to `Differentiable`.
42374226
// Get the `TangentVector` associated type of the `value:` result type.
4227+
auto valueResultType = valueResultElt.getType();
4228+
if (valueResultType->hasTypeParameter())
4229+
valueResultType = derivative->mapTypeIntoContext(valueResultType);
4230+
auto valueResultConf = TypeChecker::conformsToProtocol(
4231+
valueResultType, diffableProto, derivative->getDeclContext(), None);
4232+
if (!valueResultConf) {
4233+
diags.diagnose(attr->getLocation(),
4234+
diag::derivative_attr_result_value_not_differentiable,
4235+
valueResultElt.getType());
4236+
return true;
4237+
}
42384238
auto resultTanType = valueResultConf.getTypeWitnessByName(
42394239
valueResultType, Ctx.Id_TangentVector);
42404240

test/AutoDiff/Sema/derivative_attr_type_checking.swift

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,22 @@ func vjpSubtractWrt1(x: Float, y: Float) -> (value: Float, pullback: (Float) ->
6161
return (x - y, { $0 })
6262
}
6363

64+
// Test invalid original function.
65+
66+
// expected-error @+1 {{use of unresolved identifier 'nonexistentFunction'}}
67+
@derivative(of: nonexistentFunction)
68+
func vjpOriginalFunctionNotFound(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
69+
fatalError()
70+
}
71+
72+
// Test `@derivative` attribute where `value:` result does not conform to `Differentiable`.
73+
// Invalid original function should be diagnosed first.
74+
// expected-error @+1 {{use of unresolved identifier 'nonexistentFunction'}}
75+
@derivative(of: nonexistentFunction)
76+
func vjpOriginalFunctionNotFound2(_ x: Float) -> (value: Int, pullback: (Float) -> Float) {
77+
fatalError()
78+
}
79+
6480
// Test incorrect `@derivative` declaration type.
6581

6682
// expected-note @+1 {{'incorrectDerivativeType' defined here}}
@@ -83,7 +99,7 @@ func vjpResultIncorrectFirstLabel(x: Float) -> (Float, (Float) -> Float) {
8399
func vjpResultIncorrectSecondLabel(x: Float) -> (value: Float, (Float) -> Float) {
84100
return (x, { $0 })
85101
}
86-
// expected-error @+1 {{'@derivative(of:)' attribute requires function to return a two-element tuple; first element type 'Int' must conform to 'Differentiable'}}
102+
// expected-error @+1 {{could not find function 'incorrectDerivativeType' with expected type '(Int) -> Int'}}
87103
@derivative(of: incorrectDerivativeType)
88104
func vjpResultNotDifferentiable(x: Int) -> (
89105
value: Int, pullback: (Int) -> Int

0 commit comments

Comments
 (0)