
[API] [AD] Revamp @differentiable usages in stdlib. #21732


Merged

Conversation

dan-zheng
Contributor

@dan-zheng dan-zheng commented Jan 9, 2019

  • Use FloatingPoint rather than BinaryFloatingPoint to constrain
    differentiability.
  • Use @differentiable where clauses to constrain differentiability
    of numeric operations (see the sketch after this list).
    • The most common constraint is where Scalar : FloatingPoint, because
      Tensor originally conformed to Differentiable where Scalar : FloatingPoint;
      with the conformance change below, these clauses become
      where Scalar : Differentiable & FloatingPoint.
  • Tensor now conditionally conforms to Differentiable where Scalar : Differentiable & FloatingPoint.
  • Allow conformance requirements in @differentiable where clauses to use
    protocol composition types (e.g. Scalar : Differentiable & FloatingPoint).
  • Make VJP applications use the correct substitution map.
    • If a custom @differentiable attribute defines a VJP and where clause
      requirements, VJP applications should use a substitution map involving
      those requirements.
    • Note: more related cases need to be handled, such as @differentiable
      attributes with where clause requirements but no VJP. These cases will
      be handled later.
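For illustration, here is a minimal sketch of the resulting attribute style on a Tensor operation. This is not code from the PR: the method and VJP names are hypothetical, and the `@differentiable(wrt:vjp:where:)` spelling is assumed to match the 'tensorflow' branch at the time.

```swift
import TensorFlow

extension Tensor where Scalar : Numeric {
  // Differentiability is gated by the attribute's where clause, matching
  // Tensor's conditional conformance to Differentiable.
  @differentiable(
    wrt: (self), vjp: _vjpSquaredValue()
    where Scalar : Differentiable & FloatingPoint
  )
  func squaredValue() -> Tensor {
    return self * self
  }
}

extension Tensor where Scalar : Differentiable & FloatingPoint {
  // The VJP returns the original value together with a pullback that maps a
  // cotangent of the result to a cotangent of `self`.
  func _vjpSquaredValue() -> (Tensor, (Tensor) -> Tensor) {
    return (self * self, { v in v * (self + self) })
  }
}
```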

Todos:

  • Make more Tensor operations differentiable.
    • This includes reduction and broadcasting ops.
    • This is enabled by @differentiable where clause type-checking.
  • Use VJP functions instead of adjoint functions.
    • I've started work on this.
    • I would prefer that this be done in a separate patch, after this patch
      adds the correct @differentiable where clauses.
  • Add tests for newly @differentiable Tensor operations.

@dan-zheng dan-zheng added the tensorflow label Jan 9, 2019
@dan-zheng dan-zheng requested a review from rxwei January 9, 2019 05:11
@dan-zheng
Contributor Author

WIP, triggering tests to see what breaks.
@swift-ci Please test tensorflow

@rxwei
Contributor

rxwei commented Jan 9, 2019

I think we should change Tensor's Differentiable conformance condition first. Currently it is

extension Tensor : Differentiable where Scalar : FloatingPoint

This kinda works, but what if, in a context where Scalar : FloatingPoint, you want to differentiate through some scalar operation on an element of Tensor<Scalar>? In that case, you'll want the scalar to be differentiable as well. Therefore, we should change the conformance condition to:

extension Tensor : Differentiable where Scalar : Differentiable & FloatingPoint
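A hedged sketch of the scenario above (hypothetical function; the `scalars` element accessor is assumed): under the old condition, the tensor itself is differentiable, but scalar arithmetic on one of its elements has no `Differentiable` guarantee, so the stronger condition is needed.

```swift
import TensorFlow

// Hypothetical example, not from this PR. The `element * element` below is a
// scalar operation on a value pulled out of Tensor<Scalar>; differentiating
// through it requires Scalar : Differentiable, which the old condition
// (Scalar : FloatingPoint alone) did not guarantee.
func lossOnFirstElement<Scalar>(_ x: Tensor<Scalar>) -> Scalar
  where Scalar : Differentiable & FloatingPoint {
  let element = x.scalars[0]  // assumed element accessor
  return element * element    // scalar-level math we want to differentiate through
}
```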

Commit message:

- Use `FloatingPoint` rather than `BinaryFloatingPoint` to constrain
  differentiability.
  - Follows from:
    - swiftlang#21673
    - tensorflow/swift-bindings#11
- Use `@differentiable` where clauses to constrain differentiability
  of numeric operations.
  - The most common constraint is `where Scalar : FloatingPoint` because
    `Tensor` conditionally conforms to `Differentiable where Scalar : FloatingPoint`.

Todos:
- Make more `Tensor` operations differentiable.
  - This includes reduction and broadcasting ops.
  - This is enabled by `@differentiable` where clause type-checking.
- Use VJP functions instead of adjoint functions.
  - I would prefer that this be done in a separate patch, after this patch
    adds the correct `@differentiable` where clauses.
- Add tests for newly `@differentiable` `Tensor` operations.

Commit message:

If a custom `@differentiable` attribute defines a VJP and where clause
requirements, VJP applications should use a substitution map involving
those requirements.

Note: more related cases need to be handled, such as `@differentiable`
attributes with where clause requirements but no VJP. These cases will
be handled later.
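To make this concrete, here is a hedged sketch with hypothetical names, assuming the `vjp:`/`where` attribute spelling of that era: the method's extension only requires `Scalar : Numeric`, while the attribute's where clause and the VJP require `Scalar : Differentiable & FloatingPoint`, so applying the VJP needs a substitution map that also satisfies those stronger requirements.

```swift
import TensorFlow

// Hypothetical example, not from this PR.
extension Tensor where Scalar : Numeric {
  @differentiable(
    wrt: (self), vjp: _vjpDoubled()
    where Scalar : Differentiable & FloatingPoint
  )
  func doubled() -> Tensor {
    return self + self
  }
}

extension Tensor where Scalar : Differentiable & FloatingPoint {
  // When the differentiation pass rewrites a call to `doubled()`, it applies
  // `_vjpDoubled()`; that application must use a substitution map satisfying
  // the attribute's requirements above, not just the original
  // `Scalar : Numeric` signature.
  func _vjpDoubled() -> (Tensor, (Tensor) -> Tensor) {
    return (self + self, { v in v + v })
  }
}
```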
@dan-zheng dan-zheng force-pushed the revamp-stdlib-differentiable branch from aa7b5af to cba5a2b on January 9, 2019 09:11
@dan-zheng
Contributor Author

> Therefore, we should change the conformance condition to:
>
> extension Tensor : Differentiable where Scalar : Differentiable & FloatingPoint

This makes sense. I'll adopt this change in this PR.

Commit message (title truncated):

…tiable`.

`Tensor` now conditionally conforms to `Differentiable` where
`Scalar : Differentiable & FloatingPoint`.

All `@differentiable` where clauses and adjoint definitions have been updated
accordingly.

Allow `@differentiable` where clause conformance requirements to protocol
composition types.
@dan-zheng dan-zheng force-pushed the revamp-stdlib-differentiable branch from 214327a to 96fffb2 on January 9, 2019 09:53
@dan-zheng
Contributor Author

@swift-ci Please test tensorflow

@dan-zheng
Contributor Author

Merging to unblock progress.

@dan-zheng dan-zheng merged commit 134500f into swiftlang:tensorflow Jan 9, 2019
@rxwei
Contributor

rxwei commented Jan 9, 2019

Would you mind changing the PR description to mention the substitution map change?

@dan-zheng
Contributor Author

> Would you mind changing the PR description to mention the substitution map change?

Done! Also added info about the last commit (Tensor : Differentiable conditional conformance change).

@rxwei
Contributor

rxwei commented Jan 9, 2019

> Note: more related cases need to be handled, such as @differentiable attributes with where clause requirements but no VJP. These cases will be handled later.

Yeah, this will be handled (and required) in differentiation through generics.
