@@ -183,100 +183,102 @@ to be compatible with differentiation, including:
183
183
computation. By the sum and product rule, this is usually addition. Addition
184
184
is defined on the
185
185
[ ` Numeric ` ] ( https://developer.apple.com/documentation/swift/numeric ) protocol.
186
-
187
-
188
- Many other parts of Swift are library extensible already - for example, any type
189
- can conform to the
190
- [ ` ExpressibleByIntegerLiteral ` ] ( https://developer.apple.com/documentation/swift/expressiblebyintegerliteral )
191
- protocol, which teaches the compiler how to convert an integer literal to that
192
- type. Following this approach, we define a ` RealVectorRepresentable ` protocol,
193
- which declares the four requirements we listed above. The compiler treats any
194
- ` RealVectorRepresentable ` types as supporting differentiation. We make standard
195
- library types such as ` Float ` and `Double** conform to this protocol.
196
-
197
- ** Note:** Currently, arithmetic operators are defined on this protocol because
198
- the standard library does not have a generic ` Arithmetic ` protocol. Although
199
- most arithmetic operators are defined on ` Numeric ` and ` FloatingPoint ` , those
200
- protocols are not designed for aggregate mathematical objects like vectors. We
201
- hope to make a case for a more general arithmetic protocol in the Swift standard
202
- library.
186
+
187
+ Floating point scalars already have properties above, because of the conformance
188
+ to the ` FloatingPoint ` protocol, which inherits from the ` Numeric ` protocol.
189
+ Similarly, we define a ` VectorNumeric ` protocol, which declares the four
190
+ requirements to represent a vector space.
203
191
204
192
``` swift
205
- public protocol RealVectorRepresentable {
206
- associatedtype Scalar : FloatingPoint
207
- associatedtype Dimensionality
208
- init (_ scalar : Scalar)
209
- init (dimensionality : Dimensionality, repeating repeatedValue : Scalar)
210
- func + (lhs : Self , rhs : Self ) -> Self
211
- func - (lhs : Self , rhs : Self ) -> Self
212
- func * (lhs : Self , rhs : Self ) -> Self
213
- func / (lhs : Self , rhs : Self ) -> Self
193
+ public protocol VectorNumeric {
194
+ associatedtype ScalarElement
195
+ associatedtype Dimensionality
196
+ init (_ scalar : ScalarElement)
197
+ init (dimensionality : Dimensionality, repeating repeatedValue : ScalarElement)
198
+ func + (lhs : Self , rhs : Self ) -> Self
199
+ func - (lhs : Self , rhs : Self ) -> Self
200
+ func * (lhs : Self , rhs : Self ) -> Self
214
201
}
215
202
```
216
203
217
- To make a type support differentiation, the user can simply add a conformance to
218
- ` RealVectorRepresentable ` . For example, TensorFlow’s ` Tensor<Scalar> ` type
219
- supports differentiation by conditionally conforming to the
220
- ` RealVectorRepresentable ` protocol when the associated type ` Scalar ` conforms to
221
- ` FloatingPoint ` .
204
+ ` VectorNumeric ` and ` Numeric ` / ` FloatingPoint ` are semantically disjoint. We say
205
+ that a type supports scalar differentiation when it conforms to the
206
+ ` FloatingPoint ` . We say that a type supports ** vector differentiation ** when it
207
+ conforms to ` VectorNumeric ` while its ` ScalarElement ` supports ** scalar
208
+ differentiation ** (i.e. conforms to the ` FloatingPoint ` protocol) .
222
209
210
+ ** Note:** According to the standard library, ` Numeric ` is only suitable for
211
+ scalars, not for aggregate mathematical objects like vectors, and so is
212
+ ` FloatingPoint ` . Today we make ` VectorNumeric ` have duplicate operators, but we
213
+ hope to make a case for a more general numeric protocol in the Swift standard
214
+ library.
215
+
216
+ To make a type support differentiation, the user can simply add a conformance to
217
+ ` FloatingPoint ` or ` VectorNumeric ` . For example, TensorFlow’s ` Tensor<Scalar> `
218
+ type supports differentiation by conditionally conforming to the ` VectorNumeric `
219
+ protocol when the associated type ` Scalar ` conforms to ` FloatingPoint ` .
223
220
224
221
``` swift
225
- extension Tensor : RealVectorRepresentable where Scalar : FloatingPoint {
222
+ extension Tensor : VectorNumeric where Scalar : Numeric {
226
223
typealias Dimensionality = [Int32 ] // This is shape.
224
+ typealias ScalarElement = Scalar
227
225
228
- init (_ scalar : Scalar ) {
229
- self = #tfop (“ Const” , scalar)
226
+ init (_ scalar : ScalarElement ) {
227
+ self = #tfop (" Const" , scalar)
230
228
}
231
229
232
- init (dimensionality : [Int32 ], repeating repeatedValue : Scalar ) {
233
- Self = #tfop (“ Fill” , Tensor (dimensionality), repeatedValue : repeatedValue)
230
+ init (dimensionality : [Int32 ], repeating repeatedValue : ScalarElement ) {
231
+ Self = #tfop (" Fill" , Tensor (dimensionality), repeatedValue : repeatedValue)
234
232
}
233
+
234
+ func + (lhs : Tensor, rhs : Tensor) -> Tensor { ... }
235
+ func - (lhs : Tensor, rhs : Tensor) -> Tensor { ... }
236
+ func * (lhs : Tensor, rhs : Tensor) -> Tensor { ... }
235
237
}
236
238
```
237
239
238
- Since ` RealVectorRepresentable ` is general enough to provide all necessary
239
- ingredients for differentiation and the compiler doesn’t make special
240
- assumptions about well-known types, users can make any type support automatic
241
- differentiation. The following example shows a generic tree structure
242
- ` Tree<Value> ` , written as an algebraic data type, conditionally conforming to
243
- ` RealVectorRepresentable ` by recursively defining operations using pattern
244
- matching. Now, functions over ` Tree<Value> ` can be differentiated!
240
+ Since ` VectorNumeric ` is general enough to provide all necessary ingredients for
241
+ differentiation and the compiler doesn’t make special assumptions about
242
+ well-known types, users can make any type support automatic differentiation. The
243
+ following example shows a generic tree structure ` Tree<Value> ` , written as an
244
+ algebraic data type, conditionally conforming to ` VectorNumeric ` by recursively
245
+ defining operations using pattern matching. Now, functions over ` Tree<Value> `
246
+ can be differentiated!
245
247
246
248
``` swift
247
249
indirect enum Tree <Value > {
248
- case leaf (Value )
249
- case node (Tree, Value , Tree)
250
+ case leaf (Value )
251
+ case node (Tree, Value , Tree)
250
252
}
251
253
252
- extension Tree : RealVectorRepresentable where Value : RealVectorRepresentable {
253
- typealias Scalar = Value .Scalar
254
- typealias Dimensionality = Value .Dimensionality
255
-
256
- init (_ scalar : Scalar) {
257
- self = .leaf (Value (scalar))
258
- }
259
-
260
- init (dimensionality : Dimensionality, repeating repeatedValue : Scalar) {
261
- self = .leaf (Value (dimensionality : dimensionality, repeating : repeatedValue))
262
- }
263
-
264
- static func + (lhs : Tree, rhs : Tree) -> Tree {
265
- switch (lhs, rhs) {
266
- case let (.leaf (x), .leaf (y)):
267
- return .leaf (x + y)
268
- case let (.leaf (x), .node (l, y, r)):
269
- return .node (l, x + y, r)
270
- case let (.node (l, x, r), .leaf (y)):
271
- return .node (l, x + y, r)
272
- case let (.node (l0, x, r0), .node (l1, y, r1)):
273
- return .node (l0 + l1, x + y, r0 + r1)
254
+ extension Tree : VectorNumeric where Value : VectorNumeric {
255
+ typealias ScalarElement = Value .ScalarElement
256
+ typealias Dimensionality = Value .Dimensionality
257
+
258
+ init (_ scalar : ScalarElemenet) {
259
+ self = .leaf (Value (scalar))
274
260
}
275
- }
276
-
277
- static func - (lhs : Tree, rhs : Tree) -> Tree { ... }
278
- static func * (lhs : Tree, rhs : Tree) -> Tree { ... }
279
- static func / (lhs : Tree, rhs : Tree) -> Tree { ... }
261
+
262
+ init (dimensionality : Dimensionality, repeating repeatedValue : ScalarElement) {
263
+ self = .leaf (Value (dimensionality : dimensionality, repeating : repeatedValue))
264
+ }
265
+
266
+ static func + (lhs : Tree, rhs : Tree) -> Tree {
267
+ switch self {
268
+ case let (.leaf (x), .leaf (y)):
269
+ return .leaf (x + y)
270
+ case let (.leaf (x), .node (l, y, r)):
271
+ return .node (l, x + y, r)
272
+ case let (.node (l, x, r), .leaf (y)):
273
+ return .node (l, x + y, r)
274
+ case let (.node (l0, x, r0), .node (l1, y, r1)):
275
+ return .node (l0 + l0, x + y, r0 + r1)
276
+ }
277
+ }
278
+
279
+ static func - (lhs : Tree, rhs : Tree) -> Tree { ... }
280
+ static func * (lhs : Tree, rhs : Tree) -> Tree { ... }
281
+ static func / (lhs : Tree, rhs : Tree) -> Tree { ... }
280
282
}
281
283
```
282
284
@@ -289,8 +291,8 @@ have special knowledge of numeric standard library functions or distinguish
289
291
between primitive operators and other functions. We recursively determine a
290
292
function's differentiability based on:
291
293
292
- * its type signature: whether inputs and the output conform to
293
- ` RealVectorRepresentable `
294
+ * its type signature: whether inputs and the output support scalar
295
+ differentiation or vector differentiation
294
296
* its visibility: if the function body is not visible by the Swift compiler
295
297
(e.g. a C function or an argument which is a closure), then it is not
296
298
differentiable
0 commit comments