@@ -1219,53 +1219,73 @@ network layers and models are formed from smaller components stored as
properties in structure types and class types. In order to use these types for
differentiation, one must extend these types to conform to the `Differentiable`
protocol. Luckily, this need not be done manually in most cases—the compiler
- automatically synthesizes conformances when a memberwise `Differentiable`
- conformance is declared.
+ automatically synthesizes conformances when a `Differentiable` conformance is
+ declared.

##### Synthesis conditions

The compiler automatically synthesizes implementations of `Differentiable`
- protocol requirements for struct and class types. Here are the conditions for
- synthesis: The type must declare a conformance to `Differentiable` with a
- `@memberwise` attribute before the protocol name, either on the type declaration
- or on an extension in the same file. All stored properties of the conforming
- type must either be a `var` that conforms to `Differentiable` or be marked with
- the `@noDerivative` attribute. If a non-`Differentiable` or a `let` stored
- property is not marked with `@noDerivative`, then it is treated as if it has
- `@noDerivative` and the compiler emits a warning (with a fix-it in IDEs) asking
- the user to make the attribute explicit.
+ protocol requirements for struct and class types. For a type, conditions for the
+ synthesis are:
+
+ 1. There is a conformance to `Differentiable` declared for the type, either in
+    the original type declaration or in an extension.
+
+ 2. There is a `@memberwise` attribute in the conformance clause before the
+    protocol name.
+
+ 3. The conformance must be declared in the same file.
+
+ Here is an example where the synthesis conditions are satisfied.
+
+ ```swift
+ struct Model: @memberwise Differentiable {
+     var weight: SIMD4<Double>
+     var bias: Double
+     let metadata1: Float
+     let metadata2: Float
+     let usesBias: Bool
+ }
+ ```

##### Default synthesis

- By default, the compiler synthesizes a nested `TangentVector` structure type
- that contains the `TangentVector`s of all stored properties that are not marked
- with `@noDerivative`. In other words, `@noDerivative` makes a stored property
- not be included in a type's tangent vectors.
+ The compiler synthesizes a nested `TangentVector` structure type that contains
+ the `TangentVector`s of all stored properties (terms and conditions apply) that
+ conform to `Differentiable`, which we call **differentiable variables**.
+
+ Mathematically, the synthesized implementation treats the data structure as a
+ product manifold of the manifolds each differentiable variable's type
+ represents. Differentiable variables' types are required to conform to
+ `Differentiable` because the synthesized implementation needs to access each
+ differentiable variable's type's `TangentVector` associated type and invoke each
+ differentiable variable's implementation of `move(along:)` and
+ `zeroTangentVectorInitializer`. Because the synthesized implementation needs to
+ invoke `move(along:)` on each differentiable variable, the differentiable
+ variables must have a `move(along:)` which satisfies the protocol requirement
+ and can be invoked on the property. That is, the property must be either a
+ variable (`var`) or a constant (`let`) with a non-`mutating` implementation of
+ the `move(along:)` protocol requirement.
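As a rough illustration of the product-manifold reading above, the following sketch writes out by hand roughly what memberwise synthesis is described as producing. It deliberately does not use the real `Differentiable` protocol: `SketchDifferentiable`, `Counter`, and `Model` are hypothetical names, and the stand-in protocol only mirrors the `TangentVector`, `move(along:)`, and `zeroTangentVectorInitializer` requirements named in the text so that the example builds with a stock Swift toolchain.

```swift
// Hypothetical stand-in for `Differentiable`; NOT the real protocol.
protocol SketchDifferentiable {
    associatedtype TangentVector
    mutating func move(along direction: TangentVector)
    var zeroTangentVectorInitializer: () -> TangentVector { get }
}

extension Double: SketchDifferentiable {
    typealias TangentVector = Double
    mutating func move(along direction: Double) { self += direction }
    var zeroTangentVectorInitializer: () -> Double { { 0 } }
}

// A class satisfies `move(along:)` with a non-mutating method, so a `let`
// property of this type can still be moved.
final class Counter: SketchDifferentiable {
    typealias TangentVector = Double
    var value: Double
    init(_ value: Double) { self.value = value }
    func move(along direction: Double) { value += direction }
    var zeroTangentVectorInitializer: () -> Double { { 0 } }
}

struct Model {
    var weight: Double    // differentiable variable (`var`)
    let counter: Counter  // differentiable variable (`let`, non-mutating move)
    let usesBias: Bool    // does not conform, so it is left out
}

// Hand-written equivalent of the memberwise implementation (an assumption
// about its shape, based on the prose above).
extension Model: SketchDifferentiable {
    struct TangentVector {
        var weight: Double.TangentVector
        var counter: Counter.TangentVector
    }

    mutating func move(along direction: TangentVector) {
        weight.move(along: direction.weight)
        counter.move(along: direction.counter)
    }

    var zeroTangentVectorInitializer: () -> TangentVector {
        { [weightInit = weight.zeroTangentVectorInitializer,
           counterInit = counter.zeroTangentVectorInitializer] in
            TangentVector(weight: weightInit(), counter: counterInit())
        }
    }
}

var model = Model(weight: 1.0, counter: Counter(10.0), usesBias: true)
model.move(along: Model.TangentVector(weight: 0.5, counter: -2.0))
print(model.weight, model.counter.value)  // prints "1.5 8.0"
```

The tangent vector has one component per differentiable variable, and moving the whole value moves each component independently, which is the sense in which the struct behaves like a product of its members' manifolds.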

The synthesized `TangentVector` has the same effective access level as the
original type declaration. Properties in the synthesized `TangentVector` have
the same effective access level as their corresponding original properties.

- A `move(along:)` method is synthesized with a body that calls `move(along:)` for
- each pair of the original property and its corresponding property in
- `TangentVector`.
+ The synthesized `move(along:)` method calls `move(along:)` for each pair of a
+ differentiable variable and its corresponding property in `TangentVector`.

- Similarly, when memberwise derivation is possible,
- `zeroTangentVectorInitializer` is synthesized to return a closure that captures
- and calls each stored property's `zeroTangentVectorInitializer` closure.
- When memberwise derivation is not possible (e.g. for custom user-defined
- `TangentVector` types), `zeroTangentVectorInitializer` is synthesized as a
- `{ TangentVector.zero }` closure.
-
- Here's an example:
+ The synthesized `zeroTangentVectorInitializer` property returns a closure that
+ captures and calls each stored property's `zeroTangentVectorInitializer`
+ closure. When memberwise derivation is not possible (e.g. for custom
+ user-defined `TangentVector` types), `zeroTangentVectorInitializer` is
+ synthesized as a `{ TangentVector.zero }` closure.

```swift
struct Foo<T: Differentiable, U: Differentiable>: @memberwise Differentiable {
-     // `x` and `y` are the "differentiation properties".
+     // `x` and `y` are the "differentiable variables".
    var x: T
    var y: U
-     @noDerivative var customFlag: Bool
-     @noDerivative let helperVariable: T
+     let customFlag: Bool

    // The compiler synthesizes:
    //
@@ -1279,7 +1299,6 @@ struct Foo<T: Differentiable, U: Differentiable>: @memberwise Differentiable {
    //         y.move(along: direction.y)
    //     }
    //
-     //     @noDerivative
    //     var zeroTangentVectorInitializer: () -> TangentVector {
    //         { [xTanInit = x.zeroTangentVectorInitializer,
    //            yTanInit = y.zeroTangentVectorInitializer] in
@@ -1289,6 +1308,74 @@ struct Foo<T: Differentiable, U: Differentiable>: @memberwise Differentiable {
}
```
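The `{ TangentVector.zero }` fallback mentioned before this example can also be sketched by hand. The snippet below is only an illustration under stated assumptions: `SketchDifferentiable` is a hypothetical stand-in for `Differentiable` (with `TangentVector` constrained to `AdditiveArithmetic` so that `.zero` exists), and `Velocity` and `Point` are made-up types. It shows a type whose tangent vector is user-defined rather than assembled member by member.

```swift
// Hypothetical stand-in for `Differentiable`; NOT the real protocol.
protocol SketchDifferentiable {
    associatedtype TangentVector: AdditiveArithmetic
    mutating func move(along direction: TangentVector)
    var zeroTangentVectorInitializer: () -> TangentVector { get }
}

// A user-defined tangent vector type, so there is no memberwise
// `TangentVector` to build from the stored properties.
struct Velocity: AdditiveArithmetic, Equatable {
    var dx: Double
    var dy: Double

    static let zero = Velocity(dx: 0, dy: 0)

    static func + (lhs: Velocity, rhs: Velocity) -> Velocity {
        Velocity(dx: lhs.dx + rhs.dx, dy: lhs.dy + rhs.dy)
    }

    static func - (lhs: Velocity, rhs: Velocity) -> Velocity {
        Velocity(dx: lhs.dx - rhs.dx, dy: lhs.dy - rhs.dy)
    }
}

struct Point: SketchDifferentiable {
    typealias TangentVector = Velocity
    var x: Double
    var y: Double

    mutating func move(along direction: Velocity) {
        x += direction.dx
        y += direction.dy
    }

    // The fallback form described in the text: no per-property closures
    // are captured, the closure simply returns `TangentVector.zero`.
    var zeroTangentVectorInitializer: () -> Velocity {
        { TangentVector.zero }
    }
}

let point = Point(x: 3, y: 4)
let zero = point.zeroTangentVectorInitializer()
print(zero == .zero)  // prints "true"
```

Unlike the memberwise closure in the example above, this fallback does not depend on the current value at all, which is sufficient when the zero tangent has a fixed shape.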

+ ###### Opt out of synthesis for a stored property
+
+ The synthesized implementation of `Differentiable` protocol requirements already
+ excludes stored properties that are not differentiable variables, such as stored
+ properties that do not conform to `Differentiable` and `let`
+ properties that do not have a non-mutating `move(along:)`. In addition to this
+ behavior, we also introduce a `@noDerivative` declaration attribute, which can
+ be attached to properties that the programmer does not wish to include in the
+ synthesized `Differentiable` protocol requirement implementation.
+
+ When a stored property is marked with `@noDerivative` in a type that declares a
+ conformance to `Differentiable`, it will not be treated as a differentiable
+ variable regardless of whether it conforms to `Differentiable`. That is, the
+ synthesized implementation of protocol requirements will not include this
+ property.
+
+ ```swift
+ struct Foo<T: Differentiable, U: Differentiable>: @memberwise Differentiable {
+     // `x` and `y` are the "differentiable variables".
+     var x: T
+     var y: U
+     @noDerivative var customFlag: Bool
+     @noDerivative let helperVariable: T
+ }
+ ```
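To make the effect of opting out concrete, here is a hand-written approximation of what the synthesized members for the `Foo` above might look like. This is a sketch under assumptions, not compiler output: `SketchDifferentiable` is a hypothetical stand-in for `Differentiable` that keeps only the `TangentVector` and `move(along:)` requirements, and the `@noDerivative` attribute is shown as a comment because it is not available outside the differentiable programming toolchain.

```swift
// Hypothetical stand-in for `Differentiable`; NOT the real protocol.
protocol SketchDifferentiable {
    associatedtype TangentVector
    mutating func move(along direction: TangentVector)
}

extension Float: SketchDifferentiable {
    typealias TangentVector = Float
    mutating func move(along direction: Float) { self += direction }
}

struct Foo<T: SketchDifferentiable, U: SketchDifferentiable> {
    var x: T
    var y: U
    /* @noDerivative */ var customFlag: Bool
    /* @noDerivative */ let helperVariable: T
}

extension Foo: SketchDifferentiable {
    // Only `x` and `y` contribute. `helperVariable` is left out even though
    // `T` conforms, because it has been opted out.
    struct TangentVector {
        var x: T.TangentVector
        var y: U.TangentVector
    }

    mutating func move(along direction: TangentVector) {
        x.move(along: direction.x)
        y.move(along: direction.y)
    }
}

var foo = Foo<Float, Float>(x: 1, y: 2, customFlag: true, helperVariable: 10)
foo.move(along: Foo<Float, Float>.TangentVector(x: 0.5, y: -1))
print(foo.x, foo.y, foo.helperVariable)  // prints "1.5 1.0 10.0"
```

Moving `foo` changes `x` and `y` but leaves `helperVariable` untouched, which matches the description of an opted-out property being ignored by the synthesized requirements.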
+
+ For clarity as to which stored properties are to be included for
+ differentiation, the compiler will recommend that all stored properties that
+ cannot be included as differentiable variables (due to either lacking a
+ conformance to `Differentiable` or being a non-`class`-bound `let` property) be
+ marked with `@noDerivative`. When a property is not included as a differentiable
+ variable and is not marked with `@noDerivative`, the compiler produces a warning
+ asking the user to make the exclusion explicit along with fix-it suggestions in
+ IDEs.
+
+ ```swift
+ struct Foo<T: Differentiable, U: Differentiable>: @memberwise Differentiable {
+     // `x` and `y` are the "differentiable variables".
+     var x: T
+     var y: U
+     var customFlag: Bool
+     let helperVariable: T
+ }
+ ```
+
+ ```console
+ test.swift:5:4: warning: stored property 'customFlag' has no derivative because 'Bool' does not conform to 'Differentiable'
+     var customFlag: Bool
+
+ test.swift:5:4: note: add a '@noDerivative' attribute to make it explicit
+     var customFlag: Bool
+     ^
+     @noDerivative
+
+ test.swift:6:4: warning: synthesis of the 'Differentiable.move(along:)' requirement for 'Foo' requires all stored properties not marked with `@noDerivative` to be mutable
+     let helperVariable: T
+
+ test.swift:6:4: note: change 'let' to 'var' to make it mutable
+     let helperVariable: T
+     ^~~
+     var
+
+ test.swift:6:4: note: add a '@noDerivative' attribute to make it explicit
+     let helperVariable: T
+     ^
+     @noDerivative
+ ```
+
##### Shortcut synthesis

In certain cases, it is not ideal to keep `Self` and `TangentVector` as separate