[AutoDiff] Fix @differentiable(linear) type-checking. #28927


Merged: 2 commits into swiftlang:tensorflow on Dec 22, 2019

Conversation

dan-zheng (Contributor)

For all non-`@noDerivative` parameter and result types, `@differentiable(linear)` function types should require and imply `T: Differentiable`, `T == T.TangentVector` requirements instead of `T: Differentiable & AdditiveArithmetic`.

Update tests.


Type-checking example:

// linear.swift
struct D: Differentiable {}
// `D` does not conform to `AdditiveArithmetic`.
// `D` is not equal to `D.TangentVector`.
func test(_: @differentiable(linear) (D) -> D) {}

Before:

$ swift linear.swift
# Missing error.

After:

$ swift linear.swift
linear.swift:5:39: error: parameter type 'D' does not conform to 'Differentiable' and satisfy 'D == D.TangentVector', but the enclosing function type is '@differentiable(linear)'; did you want to add '@noDerivative' to this parameter?
func test(_: @differentiable(linear) (D) -> D) {}
                                      ^
                                      @noDerivative
linear.swift:5:45: error: result type 'D' does not conform to 'Differentiable' and satisfy 'D == D.TangentVector', but the enclosing function type is '@differentiable(linear)'
func test(_: @differentiable(linear) (D) -> D) {}
                                            ^

Inference example:

func inferred<T, U>(_: @differentiable(linear) (T) -> U) {}

Before:

$ swiftc -print-ast linear.swift
internal func inferred<T, U>(_: @differentiable(linear) (T) -> U) where T : AdditiveArithmetic, T : Differentiable, U : AdditiveArithmetic, U : Differentiable

After:

$ swiftc -print-ast linear.swift
internal func inferred<T, U>(_: @differentiable(linear) (T) -> U) where T : Differentiable, T == T.TangentVector, U : Differentiable, U == U.TangentVector
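The shape of the newly inferred requirements can be mimicked with a toy protocol (a hedged sketch: `MockDifferentiable` and `requiresSelfTangent` are illustrative stand-ins, not the real `Differentiable` protocol or any compiler API):

```swift
// Toy stand-in for `Differentiable` (the real protocol has more
// requirements; this sketch only models the associated type).
protocol MockDifferentiable {
    associatedtype TangentVector
}

// Like the real conformance, `Float` is its own tangent vector.
extension Float: MockDifferentiable {
    typealias TangentVector = Float
}

// Mirrors the requirements the patch now infers for non-`@noDerivative`
// parameters/results of `@differentiable(linear)` function types:
// `T: Differentiable, T == T.TangentVector`.
func requiresSelfTangent<T>(_ x: T) -> T
    where T: MockDifferentiable, T == T.TangentVector {
    return x
}

print(requiresSelfTangent(Float(1.5)))  // 1.5
```

A type like `D` above, whose `TangentVector` is a separate struct, would fail the `T == T.TangentVector` same-type constraint at the call site.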

dan-zheng (Contributor Author)

@swift-ci Please test tensorflow

rxwei (Contributor) commented Dec 22, 2019

linear.swift:5:39: error: parameter type 'D' does not conform to 'Differentiable' and satisfy 'D == D.TangentVector', but the enclosing function type is '@differentiable(linear)'; did you want to add '@noDerivative' to this parameter?
func test(_: @differentiable(linear) (D) -> D) {}
                                      ^
                                      @noDerivative
linear.swift:5:45: error: result type 'D' does not conform to 'Differentiable' and satisfy 'D == D.TangentVector', but the enclosing function type is '@differentiable(linear)'
func test(_: @differentiable(linear) (D) -> D) {}
                                            ^

Two issues:

  • Why are there duplicate errors?
  • The fix-it is not appropriate when there's only one argument, because the function will still be ill-formed when you add @noDerivative.

dan-zheng (Contributor Author) commented Dec 22, 2019

  • Why are there duplicate errors?

The errors are not duplicates: one is for the invalid parameter and one is for the invalid result.

  • The fix-it is not appropriate when there's only one argument, because the function will still be ill-formed when you add @noDerivative.

This was the existing behavior. We could fix the fix-it so it appears only if at least one differentiability parameter remains after adding @noDerivative.

Edit: fixed in 642f2e4. Emit @noDerivative fix-it only when there is at least one valid
differentiability/linearity parameter.
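
The guard described above can be sketched roughly as follows (all names here are hypothetical, not the actual type-checker API):

```swift
// Hypothetical sketch of the fix-it guard: suggest `@noDerivative`
// for an invalid parameter only if at least one valid
// differentiability parameter would remain afterwards.
struct Param {
    let isValidDifferentiabilityParam: Bool
}

func shouldEmitNoDerivativeFixIt(invalidIndex: Int, params: [Param]) -> Bool {
    for (i, p) in params.enumerated()
        where i != invalidIndex && p.isValidDifferentiabilityParam {
        return true
    }
    return false
}

// Sole parameter is invalid: adding `@noDerivative` would leave the
// function type ill-formed, so no fix-it is emitted.
print(shouldEmitNoDerivativeFixIt(
    invalidIndex: 0,
    params: [Param(isValidDifferentiabilityParam: false)]))  // false
```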

rxwei (Contributor) left a comment


Not adding T.TangentVector == T was by design, but adding this constraint is okay with me. Did you encounter any concrete use cases where the absence of T.TangentVector == T requires extra user attention?

dan-zheng (Contributor Author) commented Dec 22, 2019

Not adding T.TangentVector == T was by design, but adding this constraint is okay with me. Did you encounter any concrete use cases where the absence of T.TangentVector == T requires extra user attention?

Ah, interesting. I assumed that `T == T.TangentVector` is necessary since `@transpose` attribute type-checking requires it for linearity parameters; this patch makes `@differentiable(linear)` type-checking consistent with that. But I guess it may not be necessary.


Do we want to consistently require T: Differentiable, T == T.TangentVector for @transpose and @differentiable(linear) function types, or to relax it to T: Differentiable & AdditiveArithmetic?

Relaxing the requirement to make @differentiable(linear) function types work with more types seems preferable to me, if possible. For example, relaxing the requirement makes default transposes for protocol requirements like AdditiveArithmetic.+ work for more conforming types.
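
For intuition on why `@transpose` wants `T == T.TangentVector` (a hedged sketch, not the compiler's implementation): the transpose of a linear map runs between tangent spaces in the reverse direction, so when each type is its own tangent vector, the transpose keeps the mirrored parameter/result types with no conversions:

```swift
// The linear map `sum: (Float, Float) -> Float` has transpose
// `(Float) -> (Float, Float)`, which broadcasts its argument.
// Both signatures stay well-typed without any tangent-vector
// conversions because Float == Float.TangentVector.
func sum(_ x: Float, _ y: Float) -> Float { x + y }
func sumTranspose(_ v: Float) -> (Float, Float) { (v, v) }

print(sum(1, 2))          // 3.0
print(sumTranspose(3.0))  // (3.0, 3.0)
```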

dan-zheng (Contributor Author)

Consensus after chatting: let's consistently require T == T.TangentVector for now.

Emit `@noDerivative` fixit only when there is at least one valid
differentiability/linearity parameter. Otherwise, adding `@noDerivative`
results in an ill-formed `@differentiable` function type.
dan-zheng (Contributor Author)

@swift-ci Please test tensorflow

@dan-zheng dan-zheng merged commit d10f51a into swiftlang:tensorflow Dec 22, 2019
@dan-zheng dan-zheng deleted the fix-diff(linear)-inference branch December 22, 2019 22:11
Labels
tensorflow This is for "tensorflow" branch PRs.