Skip to content

[AutoDiff] Improve @derivative attribute diagnostics. #29918

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 20, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions include/swift/AST/DiagnosticsSema.def
Original file line number Diff line number Diff line change
Expand Up @@ -2972,18 +2972,18 @@ NOTE(protocol_witness_missing_differentiable_attr,none,
// @derivative
ERROR(derivative_attr_expected_result_tuple,none,
"'@derivative(of:)' attribute requires function to return a two-element "
"tuple of type '(value: T..., pullback: (U.TangentVector) -> T.TangentVector...)' "
"or '(value: T..., differential: (T.TangentVector...) -> U.TangentVector)'", ())
"tuple; first element must have label 'value:' and second element must "
"have label 'pullback:' or 'differential:'", ())
ERROR(derivative_attr_invalid_result_tuple_value_label,none,
"'@derivative(of:)' attribute requires function to return a two-element "
"tuple (first element must have label 'value:')", ())
"tuple; first element must have label 'value:'", ())
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: derivative_attr_expected_result_tuple (1) is somewhat duplicated by derivative_attr_invalid_result_tuple_value_label (2) and derivative_attr_invalid_result_tuple_func_label (3).

(1) is produced when the @derivative function type doesn't return a two-element tuple. (2) and (3) are produced when the returned two-element tuple's labels are incorrect.

Unless someone has suggestions, I'm content to leave (2) and (3) as is, since they're more specific than (1).

ERROR(derivative_attr_invalid_result_tuple_func_label,none,
"'@derivative(of:)' attribute requires function to return a two-element "
"tuple (second element must have label 'pullback:' or 'differential:')",
"tuple; second element must have label 'pullback:' or 'differential:'",
())
ERROR(derivative_attr_result_value_not_differentiable,none,
"'@derivative(of:)' attribute requires function to return a two-element "
"tuple (first element type %0 must conform to 'Differentiable')", (Type))
"tuple; first element type %0 must conform to 'Differentiable'", (Type))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why's "conform to Differentiable" special here in a derivative function? Would it be better to just emit a diagnostic that suggests the expected derivative type?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be better to just emit a diagnostic that suggests the expected derivative type?

That would be ideal. However, for @derivative attribute type-checking, the "expected derivative type" is not known - we start with the type of the @derivative declaration and try to compute the appropriate original function type.

Copy link
Contributor

@rxwei rxwei Feb 19, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, but I feel the "must conform to Differentiable" error message is making it worse because it seems to suggest that derivative function's value: result is somehow special with regards to Differentiable conformance, while it's in fact simply the same as the original function's result. Just my two cents. I think a better diagnostic could be:

'value:' element type %0 must be the same as the result of the original function being differentiated

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand your point. I pushed some further diagnostics improvements along this direction.

Now, original function lookup occurs before the "does first element type conform to Differentiable?" check. In effect, this delays "find Differentiable conformance for value: result" as late as possible, so original function lookup diagnostics trigger first.


Example:

import _Differentiation
@derivative(of: nonexistentFunction)
func derivative(_ x: Float) -> (value: Int, pullback: (Float) -> Float) {
  fatalError()
}

Before: unideal diagnostic about value: result.

derivative.swift:2:2: error: '@derivative(of:)' attribute requires function to return a two-element tuple (first element type 'Int' must conform to 'Differentiable')
@derivative(of: nonexistentFunction)
 ^

After: ideal diagnostic about original function not found.

derivative.swift:2:17: error: use of unresolved identifier 'nonexistentFunction'
@derivative(of: nonexistentFunction)
                ^

Example:

import _Differentiation

func original(_ x: Int) -> Int { x }

@derivative(of: original)
func vjpOriginalFunctionNotFound2(_ x: Float) -> (value: Int, pullback: (Float) -> Float) {
  fatalError()
}

Before: unideal diagnostic about value: result.

derivative2.swift:5:2: error: '@derivative(of:)' attribute requires function to return a two-element tuple (first element type 'Int' must conform to 'Differentiable')
@derivative(of: original)
 ^

After: ideal diagnostic about original function not found.

derivative2.swift:5:17: error: could not find function 'original' with expected type '(Float) -> Int'
@derivative(of: original)
                ^

Let me know if these changes make sense, and if you'd like further changes! I'd like to defer major diagnostic changes until later to unblock progress.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is very nice!

ERROR(derivative_attr_result_func_type_mismatch,none,
"function result's %0 type does not match %1", (Identifier, DeclName))
NOTE(derivative_attr_result_func_type_mismatch_note,none,
Expand Down
78 changes: 42 additions & 36 deletions lib/Sema/TypeCheckAttr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3808,6 +3808,31 @@ bool resolveDifferentiableAttrDerivativeFunctions(
return false;
}

/// Checks whether differentiable programming is enabled for the given
/// differentiation-related attribute. Returns true on error.
bool checkIfDifferentiableProgrammingEnabled(
ASTContext &ctx, DeclAttribute *attr) {
auto &diags = ctx.Diags;
// The experimental differentiable programming flag must be enabled.
if (!ctx.LangOpts.EnableExperimentalDifferentiableProgramming) {
diags
.diagnose(attr->getLocation(),
diag::experimental_differentiable_programming_disabled)
.highlight(attr->getRangeWithAt());
return true;
}
// The `Differentiable` protocol must be available.
// If unavailable, the `_Differentiation` module should be imported.
if (!ctx.getProtocol(KnownProtocolKind::Differentiable)) {
diags
.diagnose(attr->getLocation(), diag::attr_used_without_required_module,
attr, ctx.Id_Differentiation)
.highlight(attr->getRangeWithAt());
return true;
}
return false;
}

llvm::Expected<IndexSubset *> DifferentiableAttributeTypeCheckRequest::evaluate(
Evaluator &evaluator, DifferentiableAttr *attr) const {
// Skip type-checking for implicit `@differentiable` attributes. We currently
Expand All @@ -3824,21 +3849,8 @@ llvm::Expected<IndexSubset *> DifferentiableAttributeTypeCheckRequest::evaluate(
auto &diags = ctx.Diags;
// `@differentiable` attribute requires experimental differentiable
// programming to be enabled.
if (!ctx.LangOpts.EnableExperimentalDifferentiableProgramming) {
diags
.diagnose(attr->getLocation(),
diag::experimental_differentiable_programming_disabled)
.highlight(attr->getRangeWithAt());
return nullptr;
}
// The `Differentiable` protocol must be available.
if (!ctx.getProtocol(KnownProtocolKind::Differentiable)) {
diags
.diagnose(attr->getLocation(), diag::attr_used_without_required_module,
attr, ctx.Id_Differentiation)
.highlight(attr->getRangeWithAt());
if (checkIfDifferentiableProgrammingEnabled(ctx, attr))
return nullptr;
}

// Derivative registration is disabled for `@differentiable(linear)`
// attributes. Instead, use `@transpose` attribute to register transpose
Expand Down Expand Up @@ -3990,7 +4002,7 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
(void)attr->getParameterIndices();
}

/// Typechecks the given derivative attribute `attr` on decl `D`.
/// Type-checks the given `@derivative` attribute `attr` on declaration `D`.
///
/// Effects are:
/// - Sets the original function and parameter indices on `attr`.
Expand All @@ -4000,19 +4012,13 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
/// \returns true on error, false on success.
static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
DerivativeAttr *attr) {
// Note: Implementation must be idempotent because it can get called multiple
// Note: Implementation must be idempotent because it may be called multiple
// times for the same attribute.

auto &diags = Ctx.Diags;

// `@derivative` attribute requires experimental differentiable programming
// to be enabled.
auto &ctx = D->getASTContext();
if (!ctx.LangOpts.EnableExperimentalDifferentiableProgramming) {
diags.diagnose(attr->getLocation(),
diag::experimental_differentiable_programming_disabled);
if (checkIfDifferentiableProgrammingEnabled(Ctx, attr))
return true;
}
auto *derivative = cast<FuncDecl>(D);
auto lookupConformance =
LookUpConformanceInModule(D->getDeclContext()->getParentModule());
Expand Down Expand Up @@ -4054,19 +4060,6 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
return true;
}
attr->setDerivativeKind(kind);
// `value: R` result tuple element must conform to `Differentiable`.
auto diffableProto = Ctx.getProtocol(KnownProtocolKind::Differentiable);
auto valueResultType = valueResultElt.getType();
if (valueResultType->hasTypeParameter())
valueResultType = derivative->mapTypeIntoContext(valueResultType);
auto valueResultConf = TypeChecker::conformsToProtocol(
valueResultType, diffableProto, derivative->getDeclContext(), None);
if (!valueResultConf) {
diags.diagnose(attr->getLocation(),
diag::derivative_attr_result_value_not_differentiable,
valueResultElt.getType());
return true;
}

// Compute expected original function type and look up original function.
auto *originalFnType =
Expand Down Expand Up @@ -4221,6 +4214,7 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
diffParamTypes);

// Get the differentiability parameters' `TangentVector` associated types.
auto *diffableProto = Ctx.getProtocol(KnownProtocolKind::Differentiable);
auto diffParamTanTypes =
map<SmallVector<TupleTypeElt, 4>>(diffParamTypes, [&](Type paramType) {
if (paramType->hasTypeParameter())
Expand All @@ -4234,7 +4228,19 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
return TupleTypeElt(paramAssocType);
});

// `value: R` result tuple element must conform to `Differentiable`.
// Get the `TangentVector` associated type of the `value:` result type.
auto valueResultType = valueResultElt.getType();
if (valueResultType->hasTypeParameter())
valueResultType = derivative->mapTypeIntoContext(valueResultType);
auto valueResultConf = TypeChecker::conformsToProtocol(
valueResultType, diffableProto, derivative->getDeclContext(), None);
if (!valueResultConf) {
diags.diagnose(attr->getLocation(),
diag::derivative_attr_result_value_not_differentiable,
valueResultElt.getType());
return true;
}
auto resultTanType = valueResultConf.getTypeWitnessByName(
valueResultType, Ctx.Id_TangentVector);

Expand Down
24 changes: 20 additions & 4 deletions test/AutoDiff/Sema/derivative_attr_type_checking.swift
Original file line number Diff line number Diff line change
Expand Up @@ -61,29 +61,45 @@ func vjpSubtractWrt1(x: Float, y: Float) -> (value: Float, pullback: (Float) ->
return (x - y, { $0 })
}

// Test invalid original function.

// expected-error @+1 {{use of unresolved identifier 'nonexistentFunction'}}
@derivative(of: nonexistentFunction)
func vjpOriginalFunctionNotFound(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
fatalError()
}

// Test `@derivative` attribute where `value:` result does not conform to `Differentiable`.
// Invalid original function should be diagnosed first.
// expected-error @+1 {{use of unresolved identifier 'nonexistentFunction'}}
@derivative(of: nonexistentFunction)
func vjpOriginalFunctionNotFound2(_ x: Float) -> (value: Int, pullback: (Float) -> Float) {
fatalError()
}

// Test incorrect `@derivative` declaration type.

// expected-note @+1 {{'incorrectDerivativeType' defined here}}
func incorrectDerivativeType(_ x: Float) -> Float {
return x
}

// expected-error @+1 {{'@derivative(of:)' attribute requires function to return a two-element tuple of type '(value: T..., pullback: (U.TangentVector) -> T.TangentVector...)' or '(value: T..., differential: (T.TangentVector...) -> U.TangentVector)'}}
// expected-error @+1 {{'@derivative(of:)' attribute requires function to return a two-element tuple; first element must have label 'value:' and second element must have label 'pullback:' or 'differential:'}}
@derivative(of: incorrectDerivativeType)
func jvpResultIncorrect(x: Float) -> Float {
return x
}
// expected-error @+1 {{'@derivative(of:)' attribute requires function to return a two-element tuple (first element must have label 'value:'}}
// expected-error @+1 {{'@derivative(of:)' attribute requires function to return a two-element tuple; first element must have label 'value:'}}
@derivative(of: incorrectDerivativeType)
func vjpResultIncorrectFirstLabel(x: Float) -> (Float, (Float) -> Float) {
return (x, { $0 })
}
// expected-error @+1 {{'@derivative(of:)' attribute requires function to return a two-element tuple (second element must have label 'pullback:' or 'differential:')}}
// expected-error @+1 {{'@derivative(of:)' attribute requires function to return a two-element tuple; second element must have label 'pullback:' or 'differential:'}}
@derivative(of: incorrectDerivativeType)
func vjpResultIncorrectSecondLabel(x: Float) -> (value: Float, (Float) -> Float) {
return (x, { $0 })
}
// expected-error @+1 {{'@derivative(of:)' attribute requires function to return a two-element tuple (first element type 'Int' must conform to 'Differentiable')}}
// expected-error @+1 {{could not find function 'incorrectDerivativeType' with expected type '(Int) -> Int'}}
@derivative(of: incorrectDerivativeType)
func vjpResultNotDifferentiable(x: Int) -> (
value: Int, pullback: (Int) -> Int
Expand Down