Skip to content

Commit eccf9b3

Browse files
authored
Merge pull request #33927 from rxwei/clarify-derived-conformances-conditions
2 parents 9010c1b + 8df2d34 commit eccf9b3

File tree

1 file changed

+117
-30
lines changed

1 file changed

+117
-30
lines changed

docs/DifferentiableProgramming.md

Lines changed: 117 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1219,53 +1219,73 @@ network layers and models are formed from smaller components stored as
12191219
properties in structure types and class types. In order to use these types for
12201220
differentiation, one must extend these types to conform to the `Differentiable`
12211221
protocol. Luckily, this need not be done manually in most cases—the compiler
1222-
automatically synthesizes conformances when a memberwise `Differentiable`
1223-
conformance is declared.
1222+
automatically synthesizes conformances when a `Differentiable` conformance is
1223+
declared.
12241224

12251225
##### Synthesis conditions
12261226

12271227
The compiler automatically synthesizes implementations of `Differentiable`
1228-
protocol requirements for struct and class types. Here are the conditions for
1229-
synthesis: The type must declare a conformance to `Differentiable` with a
1230-
`@memberwise` attribute before the protocol name, either on the type declaration
1231-
or on an extension in the same file. All stored properties of the conforming
1232-
type must either be a `var` that conforms to `Differentiable` or be marked with
1233-
the `@noDerivative` attribute. If a non-`Differentiable` or a `let` stored
1234-
property is not marked with `@noDerivative`, then it is treated as if it has
1235-
`@noDerivative` and the compiler emits a warning (with a fix-it in IDEs) asking
1236-
the user to make the attribute explicit.
1228+
protocol requirements for struct and class types. For a type, conditions for the
1229+
synthesis are:
1230+
1231+
1. There is a conformance to `Differentiable` declared for the type, either in
1232+
the original type declaration or in an extension.
1233+
1234+
2. There is a `@memberwise` attribute in the conformance clause before the
1235+
protocol name.
1236+
1237+
3. The conformance must be declared in the same file.
1238+
1239+
Here is an example where the synthesis conditions are satisfied.
1240+
1241+
```swift
1242+
struct Model: @memberwise Differentiable {
1243+
var weight: SIMD4<Double>
1244+
var bias: Double
1245+
let metadata1: Float
1246+
let metadata2: Float
1247+
let usesBias: Bool
1248+
}
1249+
```
12371250

12381251
##### Default synthesis
12391252

1240-
By default, the compiler synthesizes a nested `TangentVector` structure type
1241-
that contains the `TangentVector`s of all stored properties that are not marked
1242-
with `@noDerivative`. In other words, `@noDerivative` makes a stored property
1243-
not be included in a type's tangent vectors.
1253+
The compiler synthesizes a nested `TangentVector` structure type that contains
1254+
the `TangentVector`s of all stored properties (terms and conditions apply) that
1255+
conform to `Differentiable`, which we call **differentiable variables**.
1256+
1257+
Mathematically, the synthesized implementation treats the data structure as a
1258+
product manifold of the manifolds each differentiable variable's type
1259+
represents. Differentiable variables' types are required to conform to
1260+
`Differentiable` because the synthesized implementation needs to access each
1261+
differentiable variable's type's `TangentVector` associated type and invoke each
1262+
differentiable variable's implementation of `move(along:)` and
1263+
`zeroTangentVectorInitializer`. Because the synthesized implementation needs to
1264+
invoke `move(along:)` on each differentiable variable, the differentiable
1265+
variables must have a `move(along:)` which satisfies the protocol requirement
1266+
and can be invoked on the property. That is, the property must be either a
1267+
variable (`var`) or a constant (`let`) with a non-`mutating` implementation of
1268+
the `move(along:)` protocol requirement.
12441269

12451270
The synthesized `TangentVector` has the same effective access level as the
12461271
original type declaration. Properties in the synthesized `TangentVector` have
12471272
the same effective access level as their corresponding original properties.
12481273

1249-
A `move(along:)` method is synthesized with a body that calls `move(along:)` for
1250-
each pair of the original property and its corresponding property in
1251-
`TangentVector`.
1274+
The synthesized `move(along:)` method calls `move(along:)` for each pair of a
1275+
differentiable variable and its corresponding property in `TangentVector`.
12521276

1253-
Similarly, when memberwise derivation is possible,
1254-
`zeroTangentVectorInitializer` is synthesized to return a closure that captures
1255-
and calls each stored property's `zeroTangentVectorInitializer` closure.
1256-
When memberwise derivation is not possible (e.g. for custom user-defined
1257-
`TangentVector` types), `zeroTangentVectorInitializer` is synthesized as a
1258-
`{ TangentVector.zero }` closure.
1259-
1260-
Here's an example:
1277+
The synthesized `zeroTangentVectorInitializer` property returns a closure that
1278+
captures and calls each stored property's `zeroTangentVectorInitializer`
1279+
closure. When memberwise derivation is not possible (e.g. for custom
1280+
user-defined `TangentVector` types), `zeroTangentVectorInitializer` is
1281+
synthesized as a `{ TangentVector.zero }` closure.
12611282

12621283
```swift
12631284
struct Foo<T: Differentiable, U: Differentiable>: @memberwise Differentiable {
1264-
// `x` and `y` are the "differentiation properties".
1285+
// `x` and `y` are the "differentiable variables".
12651286
var x: T
12661287
var y: U
1267-
@noDerivative var customFlag: Bool
1268-
@noDerivative let helperVariable: T
1288+
let customFlag: Bool
12691289

12701290
// The compiler synthesizes:
12711291
//
@@ -1279,7 +1299,6 @@ struct Foo<T: Differentiable, U: Differentiable>: @memberwise Differentiable {
12791299
// y.move(along: direction.y)
12801300
// }
12811301
//
1282-
// @noDerivative
12831302
// var zeroTangentVectorInitializer: () -> TangentVector {
12841303
// { [xTanInit = x.zeroTangentVectorInitializer,
12851304
// yTanInit = y.zeroTangentVectorInitializer] in
@@ -1289,6 +1308,74 @@ struct Foo<T: Differentiable, U: Differentiable>: @memberwise Differentiable {
12891308
}
12901309
```
12911310

1311+
###### Opt out of synthesis for a stored property
1312+
1313+
The synthesized implementation of `Differentiable` protocol requirements already
1314+
excludes stored properties that are not differentiable variables, such as stored
1315+
properties that do not conform to `Differentiable` and `let`
1316+
properties that do not have a non-mutating `move(along:)`. In addition to this
1317+
behavior, we also introduce a `@noDerivative` declaration attribute, which can
1318+
be attached to properties that the programmer does not wish to include in the
1319+
synthesized `Differentiable` protocol requirement implementation.
1320+
1321+
When a stored property is marked with `@noDerivative` in a type that declares a
1322+
conformance to `Differentiable`, it will not be treated as a differentiable
1323+
variable regardless of whether it conforms to `Differentiable`. That is, the
1324+
synthesized implementation of protocol requirements will not include this
1325+
property.
1326+
1327+
```swift
1328+
struct Foo<T: Differentiable, U: Differentiable>: @memberwise Differentiable {
1329+
// `x` and `y` are the "differentiable variables".
1330+
var x: T
1331+
var y: U
1332+
@noDerivative var customFlag: Bool
1333+
@noDerivative let helperVariable: T
1334+
}
1335+
```
1336+
1337+
For clarity as to which stored properties are to be included for
1338+
differentiation, the compiler will recommend that all stored properties that
1339+
cannot be included as differentiable variables (due to either lacking a
1340+
conformance to `Differentiable` or being a non-`class`-bound `let` property) be
1341+
marked with `@noDerivative`. When a property is not included as a differentiable
1342+
variable and is not marked with `@noDerivative`, the compiler produces a warning
1343+
asking the user to make the exclusion explicit along with fix-it suggestions in
1344+
IDEs.
1345+
1346+
```swift
1347+
struct Foo<T: Differentiable, U: Differentiable>: @memberwise Differentiable {
1348+
// `x` and `y` are the "differentiable variables".
1349+
var x: T
1350+
var y: U
1351+
var customFlag: Bool
1352+
let helperVariable: T
1353+
}
1354+
```
1355+
1356+
```console
1357+
test.swift:5:4: warning: stored property 'customFlag' has no derivative because 'Bool' does not conform to 'Differentiable'
1358+
var customFlag: Bool
1359+
1360+
test.swift:5:4: note: add a '@noDerivative' attribute to make it explicit
1361+
var customFlag: Bool
1362+
^
1363+
@noDerivative
1364+
1365+
test.swift:6:4: warning: synthesis of the 'Differentiable.move(along:)' requirement for 'Foo' requires all stored properties not marked with `@noDerivative` to be mutable
1366+
let helperVariable: T
1367+
1368+
test.swift:6:4: note: change 'let' to 'var' to make it mutable
1369+
let helperVariable: T
1370+
^~~
1371+
var
1372+
1373+
test.swift:6:4: note: add a '@noDerivative' attribute to make it explicit
1374+
let helperVariable: T
1375+
^
1376+
@noDerivative
1377+
```
1378+
12921379
##### Shortcut synthesis
12931380

12941381
In certain cases, it is not ideal to keep `Self` and `TangentVector` as separate

0 commit comments

Comments
 (0)