Skip to content

[AutoDiff] [Docs] Clarify 'Differentiable' derived conformances conditions. #33927

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 118 additions & 31 deletions docs/DifferentiableProgramming.md
Original file line number Diff line number Diff line change
Expand Up @@ -1219,53 +1219,73 @@ network layers and models are formed from smaller components stored as
properties in structure types and class types. In order to use these types for
differentiation, one must extend these types to conform to the `Differentiable`
protocol. Luckily, this need not be done manually in most cases—the compiler
automatically synthesizes conformances when a memberwise `Differentiable`
conformance is declared.
automatically synthesizes conformances when a `Differentiable` conformance is
declared.

##### Synthesis conditions

The compiler automatically synthesizes implementations of `Differentiable`
protocol requirements for struct and class types. Here are the conditions for
synthesis: The type must declare a conformance to `Differentiable` with a
`@memberwise` attribute before the protocol name, either on the type declaration
or on an extension in the same file. All stored properties of the conforming
type must either be a `var` that conforms to `Differentiable` or be marked with
the `@noDerivative` attribute. If a non-`Differentiable` or a `let` stored
property is not marked with `@noDerivative`, then it is treated as if it has
`@noDerivative` and the compiler emits a warning (with a fix-it in IDEs) asking
the user to make the attribute explicit.
protocol requirements for struct and class types. For a type, conditions for the
synthesis are:

1. There is a conformance to `Differentiable` declared for the type, either in
the original type declaration or in an extension.

2. There is a `@memberwise` attribute in the conformance clause before the
protocol name.

3. The conformance must be declared in the same file.

Here is an example where the synthesis conditions are satisfied.

```swift
struct Model: @memberwise Differentiable {
var weight: SIMD4<Double>
var bias: Double
let metadata1: Float
let metadata2: Float
let usesBias: Bool
}
```

##### Default synthesis

By default, the compiler synthesizes a nested `TangentVector` structure type
that contains the `TangentVector`s of all stored properties that are not marked
with `@noDerivative`. In other words, `@noDerivative` makes a stored property
not be included in a type's tangent vectors.
The compiler synthesizes a nested `TangentVector` structure type that contains
the `TangentVector`s of all stored properties (terms and conditions apply) that
conform to `Differentiable`, which we call **differentiable variables**.

Mathematically, the synthesized implementation treats the data structure as a
product manifold of the manifolds each differentiable variable's type
represents. Differentiable variables' types are required to conform to
`Differentiable` because the synthesized implementation needs to access each
differentiable variable's type's `TangentVector` associated type and invoke each
differentiable variable's implementation of `move(along:)` and
`zeroTangentVectorInitializer`. Because the synthesized implementation needs to
invoke `move(along:)` on each differentiable variable, the differentiable
variables must have a `move(along:)` which satisfies the protocol requirement
and can be invoked on the property. That is, the property must be either a
variable (`var`) or a constant (`let`) with a non-`mutating` implementation of
the `move(along:)` protocol requirement.

The synthesized `TangentVector` has the same effective access level as the
original type declaration. Properties in the synthesized `TangentVector` have
the same effective access level as their corresponding original properties.

A `move(along:)` method is synthesized with a body that calls `move(along:)` for
each pair of the original property and its corresponding property in
`TangentVector`.
The synthesized `move(along:)` method calls `move(along:)` for each pair of a
differentiable variable and its corresponding property in `TangentVector`.

Similarly, when memberwise derivation is possible,
`zeroTangentVectorInitializer` is synthesized to return a closure that captures
and calls each stored property's `zeroTangentVectorInitializer` closure.
When memberwise derivation is not possible (e.g. for custom user-defined
`TangentVector` types), `zeroTangentVectorInitializer` is synthesized as a
`{ TangentVector.zero }` closure.

Here's an example:
The synthesized `zeroTangentVectorInitializer` property returns a closure that
captures and calls each stored property's `zeroTangentVectorInitializer`
closure. When memberwise derivation is not possible (e.g. for custom
user-defined `TangentVector` types), `zeroTangentVectorInitializer` is
synthesized as a `{ TangentVector.zero }` closure.

```swift
struct Foo<T: Differentiable, U: Differentiable>: @memberwise Differentiable {
// `x` and `y` are the "differentiation properties".
// `x` and `y` are the "differentiable variables".
var x: T
var y: U
@noDerivative var customFlag: Bool
@noDerivative let helperVariable: T
let customFlag: Bool

// The compiler synthesizes:
//
Expand All @@ -1279,7 +1299,6 @@ struct Foo<T: Differentiable, U: Differentiable>: @memberwise Differentiable {
// y.move(along: direction.y)
// }
//
// @noDerivative
// var zeroTangentVectorInitializer: () -> TangentVector {
// { [xTanInit = x.zeroTangentVectorInitializer,
// yTanInit = y.zeroTangentVectorInitializer] in
Expand All @@ -1289,6 +1308,74 @@ struct Foo<T: Differentiable, U: Differentiable>: @memberwise Differentiable {
}
```

###### Opt out of synthesis for a stored property

The synthesized implementation of `Differentiable` protocol requirements already
excludes stored properties that are not differentiable variables, such as stored
properties that do not conform to `Differentiable` and `let`
properties that do not have a non-mutating `move(along:)`. In addition to this
behavior, we also introduce a `@noDerivative` declaration attribute, which can
be attached to properties that the programmer does not wish to include in the
synthesized `Differentiable` protocol requirement implementation.

When a stored property is marked with `@noDerivative` in a type that declares a
conformance to `Differentiable`, it will not be treated as a differentiable
variable regardless of whether it conforms to `Differentiable`. That is, the
synthesized implementation of protocol requirements will not include this
property.

```swift
struct Foo<T: Differentiable, U: Differentiable>: @memberwise Differentiable {
// `x` and `y` are the "differentiable variables".
var x: T
var y: U
@noDerivative var customFlag: Bool
@noDerivative let helperVariable: T
}
```

For clarity as to which stored properties are to be included for
differentiation, the compiler will recommend that all stored properties that
cannot be included as differentiable variables (due to either lacking a
conformance to `Differentiable` or being a non-`class`-bound `let` property) be
marked with `@noDerivative`. When a property is not included as a differentiable
variable and is not marked with `@noDerivative`, the compiler produces a warning
as asking the user to make the exclusion explicit along with fix-it suggestions
in IDEs.

```swift
struct Foo<T: Differentiable, U: Differentiable>: @memberwise Differentiable {
// `x` and `y` are the "differentiable variables".
var x: T
var y: U
var customFlag: Bool
let helperVariable: T
}
```

```console
test.swift:5:4: warning: stored property 'customFlag' has no derivative because 'Bool' does not conform to 'Differentiable'
var customFlag: Bool

test.swift:5:4: note: add a '@noDerivative' attribute to make it explicit
var customFlag: Bool
^
@noDerivative

test.swift:6:4: warning: synthesis of the 'Differentiable.move(along:)' requirement for 'Foo' requires all stored properties not marked with `@noDerivative` to be mutable
let helperVariable: T

test.swift:6:4: note: change 'let' to 'var' to make it mutable
let helperVariable: T
^~~
var

test.swift:6:4: note: add a '@noDerivative' attribute to make it explicit
let helperVariable: T
^
@noDerivative
```

##### Shortcut synthesis

In certain cases, it is not ideal to keep `Self` and `TangentVector` as separate
Expand All @@ -1302,7 +1389,7 @@ Method `move(along:)` will not be synthesized because a default implementation
already exists.

```swift
struct Point<T: Real>: @memberwise Differentiable, @memberwise AdditiveArithmetic {
struct Point<T: Real>: @memberwise Differentiable, AdditiveArithmetic {
// `x` and `y` are the "differentiation properties".
var x, y: T

Expand Down