Skip to content

[AutoDiff] Support derivative registration for more declaration kinds. #28468

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

dan-zheng
Copy link
Contributor

Make @differentiating and @transposing attributes support more original
declaration kinds: computed properties, subscripts, and initializers.

  • Change DifferentiatingAttr and TransposingAttr to store an
    AbstractFunctionDecl representing the original declaration, instead of a
    FuncDecl.
  • Change attribute parsing to support initializer/subscript DeclNames.
  • Make TypeChecker::lookupFuncDecl a static function in TypeCheckAttr.cpp.
  • Assorted parsing and type-checking gardening.

@differentiating now has feature parity with
@differentiable(jvp: ..., vjp: ...) for derivative registration.

This is a necessary step towards making @differentiating and @transposing
the canonical mechanism for registering derivative/transpose functions.

Registering non-func declaration derivatives with @differentiable attribute
jvp:/vjp: labels is now explicitly rejected.

Resolves TF-281.

Todos:

  • Upstream changes to previously upstreamed code.
  • TF-997: support @transposing attribute with initializer original
    declarations.
  • TF-988: do not reuse @differentiable attribute type-checking diagnostics
    for @differentiating/@transposing attribute type-checking.

Newly supported:

struct Foo: Differentiable {}

// Computed properties.
extension Foo {
  var computedProperty: Float { 1 }

  @differentiating(computedProperty)
  func vjpComputedProperty() -> (value: Float, pullback: (Float) -> TangentVector) { ... }
}

// Initializers.
extension Foo {
  init(_ x: Float) {}

  @differentiating(init, wrt: x)
  static func vjpInit(_ x: Float) -> (value: Foo, pullback: (TangentVector) -> Float) { ... }
}

// Subscripts.
extension Foo {
  subscript() -> Float { 1 }

  @differentiating(subscript)
  func vjpSubscript() -> (value: Float, pullback: (Float) -> TangentVector) { ... }
}

Newly diagnosed:

struct Bar: Differentiable {
  @differentiable(vjp: computedPropertyVJP)
  func instanceMethod() -> Float { 1 }

  var computedPropertyVJP: (Float, (Float) -> TangentVector) { ... }
}

// error: registered derivative 'computedPropertyVJP' must be a 'func' declaration
//   @differentiable(vjp: computedPropertyVJP)
//                        ^

Make `@differentiating` and `@transposing` attributes support more original
declaration kinds: computed properties, subscripts, and initializers.

- Change `DifferentiatingAttr` and `TransposingAttr` to store an
  `AbstractFunctionDecl` representing the original declaration, instead of a
  `FuncDecl`.
- Change attribute parsing to support initializer/subscript `DeclName`s.
- Make `TypeChecker::lookupFuncDecl` a static function in TypeCheckAttr.cpp.
- Assorted parsing and type-checking gardening.

`@differentiating` now has feature parity with
`@differentiable(jvp: ..., vjp: ...)` for derivative registration.

This is a necessary step towards making `@differentiating` and `@transposing`
the canonical mechanism for registering derivative/transpose functions.

Registering non-`func` declaration derivatives with `@differentiable` attribute
`jvp:`/`vjp:` labels is now explicitly rejected.

Resolves TF-281.

Todos:
- TF-997: support `@transposing` attribute with initializer original
  declarations.
- TF-988: do not reuse `@differentiable` attribute type-checking diagnostics
  for `@differentiating`/`@transposing` attribute type-checking.
@dan-zheng dan-zheng added the tensorflow This is for "tensorflow" branch PRs. label Nov 25, 2019
@dan-zheng
Copy link
Contributor Author

@swift-ci Please test tensorflow

@dan-zheng
Copy link
Contributor Author

Happy to address feedback later!

@dan-zheng dan-zheng merged commit 658b7f7 into swiftlang:tensorflow Nov 25, 2019
@dan-zheng dan-zheng deleted the register-derivatives-for-more-decls branch November 25, 2019 17:03
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
tensorflow This is for "tensorflow" branch PRs.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants