Skip to content
This repository was archived by the owner on Mar 30, 2022. It is now read-only.

Commit a67499a

Browse files
committed
Update VectorNumeric protocol.
1 parent 749bd19 commit a67499a

File tree

1 file changed

+76
-74
lines changed

1 file changed

+76
-74
lines changed

docs/AutomaticDifferentiation.md

Lines changed: 76 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -183,100 +183,102 @@ to be compatible with differentiation, including:
183183
computation. By the sum and product rule, this is usually addition. Addition
184184
is defined on the
185185
[`Numeric`](https://developer.apple.com/documentation/swift/numeric) protocol.
186-
187-
188-
Many other parts of Swift are library extensible already - for example, any type
189-
can conform to the
190-
[`ExpressibleByIntegerLiteral`](https://developer.apple.com/documentation/swift/expressiblebyintegerliteral)
191-
protocol, which teaches the compiler how to convert an integer literal to that
192-
type. Following this approach, we define a `RealVectorRepresentable` protocol,
193-
which declares the four requirements we listed above. The compiler treats any
194-
`RealVectorRepresentable` types as supporting differentiation. We make standard
195-
library types such as `Float` and `Double** conform to this protocol.
196-
197-
**Note:** Currently, arithmetic operators are defined on this protocol because
198-
the standard library does not have a generic `Arithmetic` protocol. Although
199-
most arithmetic operators are defined on `Numeric` and `FloatingPoint`, those
200-
protocols are not designed for aggregate mathematical objects like vectors. We
201-
hope to make a case for a more general arithmetic protocol in the Swift standard
202-
library.
186+
187+
Floating point scalars already have properties above, because of the conformance
188+
to the `FloatingPoint` protocol, which inherits from the `Numeric` protocol.
189+
Similarly, we define a `VectorNumeric` protocol, which declares the four
190+
requirements to represent a vector space.
203191

204192
```swift
205-
public protocol RealVectorRepresentable {
206-
associatedtype Scalar : FloatingPoint
207-
associatedtype Dimensionality
208-
init(_ scalar: Scalar)
209-
init(dimensionality: Dimensionality, repeating repeatedValue: Scalar)
210-
func + (lhs: Self, rhs: Self) -> Self
211-
func - (lhs: Self, rhs: Self) -> Self
212-
func * (lhs: Self, rhs: Self) -> Self
213-
func / (lhs: Self, rhs: Self) -> Self
193+
public protocol VectorNumeric {
194+
associatedtype ScalarElement
195+
associatedtype Dimensionality
196+
init(_ scalar: ScalarElement)
197+
init(dimensionality: Dimensionality, repeating repeatedValue: ScalarElement)
198+
func + (lhs: Self, rhs: Self) -> Self
199+
func - (lhs: Self, rhs: Self) -> Self
200+
func * (lhs: Self, rhs: Self) -> Self
214201
}
215202
```
216203

217-
To make a type support differentiation, the user can simply add a conformance to
218-
`RealVectorRepresentable`. For example, TensorFlow’s `Tensor<Scalar>` type
219-
supports differentiation by conditionally conforming to the
220-
`RealVectorRepresentable` protocol when the associated type `Scalar` conforms to
221-
`FloatingPoint`.
204+
`VectorNumeric` and `Numeric`/`FloatingPoint` are semantically disjoint. We say
205+
that a type supports scalar differentiation when it conforms to the
206+
`FloatingPoint`. We say that a type supports **vector differentiation** when it
207+
conforms to `VectorNumeric` while its `ScalarElement` supports **scalar
208+
differentiation** (i.e. conforms to the `FloatingPoint` protocol).
222209

210+
**Note:** According to the standard library, `Numeric` is only suitable for
211+
scalars, not for aggregate mathematical objects like vectors, and so is
212+
`FloatingPoint`. Today we make `VectorNumeric` have duplicate operators, but we
213+
hope to make a case for a more general numeric protocol in the Swift standard
214+
library.
215+
216+
To make a type support differentiation, the user can simply add a conformance to
217+
`FloatingPoint` or `VectorNumeric`. For example, TensorFlow’s `Tensor<Scalar>`
218+
type supports differentiation by conditionally conforming to the `VectorNumeric`
219+
protocol when the associated type `Scalar` conforms to `FloatingPoint`.
223220

224221
```swift
225-
extension Tensor : RealVectorRepresentable where Scalar : FloatingPoint {
222+
extension Tensor : VectorNumeric where Scalar : Numeric {
226223
typealias Dimensionality = [Int32] // This is shape.
224+
typealias ScalarElement = Scalar
227225

228-
init(_ scalar: Scalar) {
229-
self = #tfop(Const, scalar)
226+
init(_ scalar: ScalarElement) {
227+
self = #tfop("Const", scalar)
230228
}
231229

232-
init(dimensionality: [Int32], repeating repeatedValue: Scalar) {
233-
Self = #tfop(Fill, Tensor(dimensionality), repeatedValue: repeatedValue)
230+
init(dimensionality: [Int32], repeating repeatedValue: ScalarElement) {
231+
Self = #tfop("Fill", Tensor(dimensionality), repeatedValue: repeatedValue)
234232
}
233+
234+
func + (lhs: Tensor, rhs: Tensor) -> Tensor { ... }
235+
func - (lhs: Tensor, rhs: Tensor) -> Tensor { ... }
236+
func * (lhs: Tensor, rhs: Tensor) -> Tensor { ... }
235237
}
236238
```
237239

238-
Since `RealVectorRepresentable` is general enough to provide all necessary
239-
ingredients for differentiation and the compiler doesn’t make special
240-
assumptions about well-known types, users can make any type support automatic
241-
differentiation. The following example shows a generic tree structure
242-
`Tree<Value>`, written as an algebraic data type, conditionally conforming to
243-
`RealVectorRepresentable` by recursively defining operations using pattern
244-
matching. Now, functions over `Tree<Value>` can be differentiated!
240+
Since `VectorNumeric` is general enough to provide all necessary ingredients for
241+
differentiation and the compiler doesn’t make special assumptions about
242+
well-known types, users can make any type support automatic differentiation. The
243+
following example shows a generic tree structure `Tree<Value>`, written as an
244+
algebraic data type, conditionally conforming to `VectorNumeric` by recursively
245+
defining operations using pattern matching. Now, functions over `Tree<Value>`
246+
can be differentiated!
245247

246248
```swift
247249
indirect enum Tree<Value> {
248-
case leaf(Value)
249-
case node(Tree, Value, Tree)
250+
case leaf(Value)
251+
case node(Tree, Value, Tree)
250252
}
251253

252-
extension Tree : RealVectorRepresentable where Value : RealVectorRepresentable {
253-
typealias Scalar = Value.Scalar
254-
typealias Dimensionality = Value.Dimensionality
255-
256-
init(_ scalar: Scalar) {
257-
self = .leaf(Value(scalar))
258-
}
259-
260-
init(dimensionality: Dimensionality, repeating repeatedValue: Scalar) {
261-
self = .leaf(Value(dimensionality: dimensionality, repeating: repeatedValue))
262-
}
263-
264-
static func + (lhs: Tree, rhs: Tree) -> Tree {
265-
switch (lhs, rhs) {
266-
case let (.leaf(x), .leaf(y)):
267-
return .leaf(x + y)
268-
case let (.leaf(x), .node(l, y, r)):
269-
return .node(l, x + y, r)
270-
case let (.node(l, x, r), .leaf(y)):
271-
return .node(l, x + y, r)
272-
case let (.node(l0, x, r0), .node(l1, y, r1)):
273-
return .node(l0 + l1, x + y, r0 + r1)
254+
extension Tree : VectorNumeric where Value : VectorNumeric {
255+
typealias ScalarElement = Value.ScalarElement
256+
typealias Dimensionality = Value.Dimensionality
257+
258+
init(_ scalar: ScalarElemenet) {
259+
self = .leaf(Value(scalar))
274260
}
275-
}
276-
277-
static func - (lhs: Tree, rhs: Tree) -> Tree { ... }
278-
static func * (lhs: Tree, rhs: Tree) -> Tree { ... }
279-
static func / (lhs: Tree, rhs: Tree) -> Tree { ... }
261+
262+
init(dimensionality: Dimensionality, repeating repeatedValue: ScalarElement) {
263+
self = .leaf(Value(dimensionality: dimensionality, repeating: repeatedValue))
264+
}
265+
266+
static func + (lhs: Tree, rhs: Tree) -> Tree {
267+
switch self {
268+
case let (.leaf(x), .leaf(y)):
269+
return .leaf(x + y)
270+
case let (.leaf(x), .node(l, y, r)):
271+
return .node(l, x + y, r)
272+
case let (.node(l, x, r), .leaf(y)):
273+
return .node(l, x + y, r)
274+
case let (.node(l0, x, r0), .node(l1, y, r1)):
275+
return .node(l0 + l0, x + y, r0 + r1)
276+
}
277+
}
278+
279+
static func - (lhs: Tree, rhs: Tree) -> Tree { ... }
280+
static func * (lhs: Tree, rhs: Tree) -> Tree { ... }
281+
static func / (lhs: Tree, rhs: Tree) -> Tree { ... }
280282
}
281283
```
282284

@@ -289,8 +291,8 @@ have special knowledge of numeric standard library functions or distinguish
289291
between primitive operators and other functions. We recursively determine a
290292
function's differentiability based on:
291293

292-
* its type signature: whether inputs and the output conform to
293-
`RealVectorRepresentable`
294+
* its type signature: whether inputs and the output support scalar
295+
differentiation or vector differentiation
294296
* its visibility: if the function body is not visible by the Swift compiler
295297
(e.g. a C function or an argument which is a closure), then it is not
296298
differentiable

0 commit comments

Comments
 (0)