[AutoDiff] Improve `@derivative` attribute diagnostics. #29918
Previously, `@derivative` attribute type-checking produced a confusing error referencing unbound types `T` and `U`:

```
'@derivative(of:)' attribute requires function to return a two-element tuple of type '(value: T..., pullback: (U.TangentVector) -> T.TangentVector...)' or '(value: T..., differential: (T.TangentVector...) -> U.TangentVector)'
```

Now, the error is less confusing:

```
'@derivative(of:)' attribute requires function to return a two-element tuple; first element must have label 'value:' and second element must have label 'pullback:' or 'differential:'
```
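To make the required shape concrete, here is a minimal sketch in plain Swift of a function whose return type matches the new wording: a two-element tuple labeled `value:` and `pullback:`. The names `square` and `squareVJP` are illustrative and do not come from this PR, and the attribute is left as a comment so that the snippet builds without the experimental `_Differentiation` module.

```swift
// Hypothetical original function (not part of this PR).
func square(_ x: Float) -> Float {
  x * x
}

// With `import _Differentiation`, this function would be registered via:
//   @derivative(of: square)
// The attribute is commented out so the sketch builds on any toolchain.
func squareVJP(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
  // First element: the original result, labeled `value:`.
  // Second element: maps an output cotangent to an input cotangent,
  // labeled `pullback:`.
  (value: square(x), pullback: { v in 2 * x * v })
}

let result = squareVJP(3)
print(result.value, result.pullback(1))
```

Renaming either label (for example `val:` instead of `value:`) is the kind of mistake the reworded diagnostic is meant to describe.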
  ERROR(derivative_attr_invalid_result_tuple_value_label,none,
        "'@derivative(of:)' attribute requires function to return a two-element "
-       "tuple (first element must have label 'value:')", ())
+       "tuple; first element must have label 'value:'", ())
Note: `derivative_attr_expected_result_tuple` (1) is somewhat duplicated by `derivative_attr_invalid_result_tuple_value_label` (2) and `derivative_attr_invalid_result_tuple_func_label` (3).
(1) is produced when the `@derivative` function type doesn't return a two-element tuple. (2) and (3) are produced when the returned two-element tuple's labels are incorrect.
Unless someone has suggestions, I'm content to leave (2) and (3) as is, since they're more specific than (1).
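The split between the three diagnostics can be sketched as a toy classifier in plain Swift. This is not compiler code: `ResultTupleDiagnostic` and `diagnose(resultLabels:)` are hypothetical names, and the result type is modeled only by its element labels.

```swift
// Toy model of the three diagnostics discussed above; not compiler code.
enum ResultTupleDiagnostic: Equatable {
  case expectedResultTuple  // (1) result is not a two-element tuple
  case invalidValueLabel    // (2) first label is not `value`
  case invalidFuncLabel     // (3) second label is not `pullback`/`differential`
}

// `labels` holds the element labels of the derivative's result type;
// `nil` stands for an unlabeled element.
func diagnose(resultLabels labels: [String?]) -> ResultTupleDiagnostic? {
  guard labels.count == 2 else { return .expectedResultTuple }
  guard labels[0] == "value" else { return .invalidValueLabel }
  guard labels[1] == "pullback" || labels[1] == "differential" else {
    return .invalidFuncLabel
  }
  return nil  // well-formed result tuple
}
```

In this model (2) and (3) can only fire once (1) has passed, which is why they can afford to be more specific.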
  ())
  ERROR(derivative_attr_result_value_not_differentiable,none,
        "'@derivative(of:)' attribute requires function to return a two-element "
-       "tuple (first element type %0 must conform to 'Differentiable')", (Type))
+       "tuple; first element type %0 must conform to 'Differentiable'", (Type))
Why's "conform to `Differentiable`" special here in a derivative function? Would it be better to just emit a diagnostic that suggests the expected derivative type?
> Would it be better to just emit a diagnostic that suggests the expected derivative type?

That would be ideal. However, for `@derivative` attribute type-checking, the "expected derivative type" is not known: we start with the type of the `@derivative` declaration and try to compute the appropriate original function type.
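As a rough illustration of that direction of inference, the sketch below derives an expected original function type from a derivative's type. It is a simplified model, not the type checker's logic: types are plain strings, `FunctionType`, `DerivativeType`, and `expectedOriginalType(for:)` are hypothetical names, and generics and other attribute arguments are ignored.

```swift
// Toy representation of a function type; types are plain strings here.
struct FunctionType: Equatable {
  var parameters: [String]
  var result: String
}

// Hypothetical model of a derivative's type:
//   (parameters) -> (value: valueType, pullback: ...)
struct DerivativeType {
  var parameters: [String]
  var valueType: String
}

// The expected original type reuses the derivative's parameters and takes
// the `value:` element type as its result.
func expectedOriginalType(for derivative: DerivativeType) -> FunctionType {
  FunctionType(parameters: derivative.parameters, result: derivative.valueType)
}

let derivative = DerivativeType(parameters: ["Float"], valueType: "Int")
let expected = expectedOriginalType(for: derivative)
print("(\(expected.parameters.joined(separator: ", "))) -> \(expected.result)")
```

For a derivative of type `(Float) -> (value: Int, pullback: (Float) -> Float)`, this model yields `(Float) -> Int`, the same expected type reported in the second example later in this thread.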
Okay, but I feel the "must conform to `Differentiable`" error message is making it worse, because it seems to suggest that the derivative function's `value:` result is somehow special with regard to `Differentiable` conformance, while it's in fact simply the same as the original function's result. Just my two cents. I think a better diagnostic could be:
'value:' element type %0 must be the same as the result of the original function being differentiated
I understand your point. I pushed some further diagnostics improvements along this direction.
Now, original function lookup occurs before the "does first element type conform to `Differentiable`?" check. In effect, this delays "find `Differentiable` conformance for `value:` result" as late as possible, so original function lookup diagnostics trigger first.
Example:
import _Differentiation
@derivative(of: nonexistentFunction)
func derivative(_ x: Float) -> (value: Int, pullback: (Float) -> Float) {
  fatalError()
}
Before: non-ideal diagnostic about the `value:` result.
derivative.swift:2:2: error: '@derivative(of:)' attribute requires function to return a two-element tuple (first element type 'Int' must conform to 'Differentiable')
@derivative(of: nonexistentFunction)
 ^
After: ideal diagnostic about the original function not being found.
derivative.swift:2:17: error: use of unresolved identifier 'nonexistentFunction'
@derivative(of: nonexistentFunction)
                ^
Example:
import _Differentiation

func original(_ x: Int) -> Int { x }

@derivative(of: original)
func vjpOriginalFunctionNotFound2(_ x: Float) -> (value: Int, pullback: (Float) -> Float) {
  fatalError()
}
Before: non-ideal diagnostic about the `value:` result.
derivative2.swift:5:2: error: '@derivative(of:)' attribute requires function to return a two-element tuple (first element type 'Int' must conform to 'Differentiable')
@derivative(of: original)
 ^
After: ideal diagnostic about the original function not being found.
derivative2.swift:5:17: error: could not find function 'original' with expected type '(Float) -> Int'
@derivative(of: original)
                ^
Let me know if these changes make sense, and if you'd like further changes! I'd like to defer major diagnostic changes until later to unblock progress.
This is very nice!
Edit: oops, accidentally triggered CI before seeing feedback above.
Attempt to look up original function before checking whether the `value:` result conforms to `Differentiable`. This improves diagnostics: "original function not found" should be diagnosed as early as possible.
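The effect of the reordering can be sketched with a toy model in plain Swift. The names `DerivativeCheckError` and `checkDerivative(originalFound:valueIsDifferentiable:)` are hypothetical, and the two Boolean inputs stand in for the outcomes of the real checks.

```swift
// Toy model of the check ordering; not the type checker's actual code.
enum DerivativeCheckError: Error, Equatable {
  case originalFunctionNotFound
  case valueNotDifferentiable
}

// The two flags are stand-ins for the results of the real checks.
func checkDerivative(
  originalFound: Bool,
  valueIsDifferentiable: Bool
) -> DerivativeCheckError? {
  // Lookup comes first, so a missing original function is reported even
  // when the `value:` type would also fail the conformance check.
  if !originalFound { return .originalFunctionNotFound }
  if !valueIsDifferentiable { return .valueNotDifferentiable }
  return nil
}
```

When both checks would fail, as in the examples above where `value:` has type `Int`, only the lookup failure is reported.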
@swift-ci Please smoke test and merge
Add utility for checking whether differentiable programming is enabled.