This repository was archived by the owner on Mar 30, 2022. It is now read-only.

Commit 4f8b5e5: Update code example
1 parent a67499a

1 file changed

docs/AutomaticDifferentiation.md: 39 additions & 39 deletions
@@ -166,24 +166,24 @@ to be compatible with differentiation, including:

* The type must represent an arbitrarily ranked vector space (where tensors
  live). Elements of this vector space must be floating-point numeric. There is
  an associated scalar type that is also floating-point numeric.

* How to initialize an adjoint value for a parameter from a scalar, with the
  same dimensionality as this parameter. This will be used to initialize a zero
  derivative when the parameter does not contribute to the output.

* How to initialize a seed value from a value of the scalar type. This will be
  used to initialize a differentiation seed - usually `1.0`, which represents
  `dy/dy`. Note: the seed type in the adjoint can be an `Optional`, so when
  there is no back-propagated adjoint, the value will be set to `nil`. However,
  this will cause performance issues with TensorFlow’s `Tensor` type today
  (optional checks causing send/receive). We need to finish the implementation
  of constant expression analysis to be able to fold the optional check away.

* How values of this type combine at data flow fan-ins in the adjoint
  computation. By the sum and product rules, this is usually addition. Addition
  is defined on the
  [`Numeric`](https://developer.apple.com/documentation/swift/numeric) protocol.

Floating-point scalars already have the properties above because of their
conformance to the `FloatingPoint` protocol, which inherits from the `Numeric`
protocol. Similarly, we define a `VectorNumeric` protocol, which declares the
four requirements above.
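A minimal sketch of what such a protocol could declare, inferred from the four requirements above (the exact declaration in the document may differ):

```swift
public protocol VectorNumeric {
  // The floating-point scalar type, e.g. Float for Tensor<Float>.
  associatedtype ScalarElement
  // The dimensionality (shape) type, e.g. [Int32] for Tensor.
  associatedtype Dimensionality

  // Broadcast a scalar; used to form differentiation seeds such as 1.0.
  init(_ scalar: ScalarElement)

  // A value of the given dimensionality filled with a scalar; used to form a
  // zero derivative with the same dimensionality as a parameter.
  init(dimensionality: Dimensionality, repeating repeatedValue: ScalarElement)

  // Combining adjoint values at data flow fan-ins is usually addition.
  static func + (lhs: Self, rhs: Self) -> Self
  static func - (lhs: Self, rhs: Self) -> Self
  static func * (lhs: Self, rhs: Self) -> Self
}
```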
@@ -220,20 +220,20 @@ protocol when the associated type `Scalar` conforms to `FloatingPoint`.

```swift
extension Tensor : VectorNumeric where Scalar : Numeric {
  typealias Dimensionality = [Int32] // This is shape.
  typealias ScalarElement = Scalar

  init(_ scalar: ScalarElement) {
    self = #tfop("Const", scalar)
  }

  init(dimensionality: [Int32], repeating repeatedValue: ScalarElement) {
    self = #tfop("Fill", Tensor(dimensionality), repeatedValue: repeatedValue)
  }

  static func + (lhs: Tensor, rhs: Tensor) -> Tensor { ... }
  static func - (lhs: Tensor, rhs: Tensor) -> Tensor { ... }
  static func * (lhs: Tensor, rhs: Tensor) -> Tensor { ... }
}
```
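Code written against `VectorNumeric` then works uniformly for any conforming type. A hypothetical illustration (not from the document):

```swift
// Doubles any VectorNumeric value using only the protocol requirements,
// so it applies equally to Tensor<Float> and other conforming types.
func double<T : VectorNumeric>(_ x: T) -> T {
  return x + x
}
```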

@@ -254,15 +254,15 @@ indirect enum Tree<Value> {

extension Tree : VectorNumeric where Value : VectorNumeric {
  typealias ScalarElement = Value.ScalarElement
  typealias Dimensionality = Value.Dimensionality

  init(_ scalar: ScalarElement) {
    self = .leaf(Value(scalar))
  }

  init(dimensionality: Dimensionality, repeating repeatedValue: ScalarElement) {
    self = .leaf(Value(dimensionality: dimensionality, repeating: repeatedValue))
  }

  static func + (lhs: Tree, rhs: Tree) -> Tree {
    switch (lhs, rhs) {
    case let (.leaf(x), .leaf(y)):
@@ -275,7 +275,7 @@ extension Tree : VectorNumeric where Value : VectorNumeric {
      return .node(l0 + l1, x + y, r0 + r1)
    }
  }

  static func - (lhs: Tree, rhs: Tree) -> Tree { ... }
  static func * (lhs: Tree, rhs: Tree) -> Tree { ... }
  static func / (lhs: Tree, rhs: Tree) -> Tree { ... }
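A hypothetical use of this conditional conformance, assuming `Float` is also made to conform to `VectorNumeric` and the `leaf`/`node` cases shown above:

```swift
// Assumes Float : VectorNumeric; element-wise addition of two trees
// with the same structure.
let a: Tree<Float> = .node(.leaf(1), 2, .leaf(3))
let b: Tree<Float> = .node(.leaf(4), 5, .leaf(6))
let sum = a + b // .node(.leaf(5), 7, .leaf(9))
```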
@@ -320,15 +320,15 @@ differentiability.
// The corresponding adjoint to call is `dTanh`.
@differentiable(reverse, adjoint: dTanh)
func tanh(_ x: Float) -> Float {
  ... some super low-level assembly tanh implementation ...
}
// d/dx tanh(x) = 1 - (tanh(x))^2
//
// Here, y is the original result of tanh(x), and x is the input parameter of the
// original function. We don't need to use `x` in tanh's adjoint because we already
// have access to the original result.
func dTanh(x: Float, y: Float, seed: Float) -> Float {
  return (1.0 - y * y) * seed
}
```
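An adjoint like `dTanh` can be sanity-checked against a finite-difference estimate. A hypothetical check (not part of the document), assuming a callable `tanh` implementation:

```swift
// Compare the analytic adjoint with a central-difference approximation.
let x: Float = 0.5
let y = tanh(x)
let analytic = dTanh(x: x, y: y, seed: 1.0) // (1 - tanh(x)^2) * seed
let eps: Float = 1e-3
let numeric = (tanh(x + eps) - tanh(x - eps)) / (2 * eps)
// analytic and numeric should both be ≈ 0.7864 and agree to ~1e-4.
```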

@@ -344,18 +344,18 @@ defined as instance methods, e.g. `FloatingPoint.squareRoot()` and

```swift
extension Tensor {
  // Differentiable with respect to `self` (the input) and the first parameter
  // (the filter) using reverse-mode AD. The corresponding adjoint to call
  // is `dConv`.
  @differentiable(reverse, withRespectTo: (self, .0), adjoint: dConv)
  func convolved(withFilter k: Tensor, strides: [Int32], padding: Padding) -> Tensor {
    return #tfop("Conv2D", ...)
  }

  func dConv(k: Tensor, strides: [Int32], padding: Padding,
             y: Tensor, seed: Tensor) -> Tensor {
    ...
  }
}
```
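For orientation, a hypothetical call site; the values `image` and `filter` and the `Padding` case are illustrative assumptions, not from the document:

```swift
// image and filter are Tensor<Float> values prepared elsewhere.
let output = image.convolved(withFilter: filter,
                             strides: [1, 1, 1, 1],
                             padding: .same)
```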

@@ -373,15 +373,15 @@ A trivial example is shown as follows:

```swift
@differentiable(reverse, adjoint: dTanh)
func tanh(_ x: Float) -> Float {
  ... some super low-level assembly tanh implementation ...
}

func dTanh(x: Float, y: Float, seed: Float) -> Float {
  return (1.0 - (y * y)) * seed
}

func foo(_ x: Float, _ y: Float) -> Float {
  return tanh(x) + tanh(y)
}

// Get the gradient function of tanh.
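// A plausible continuation (illustrative, not verbatim from the document),
// using the #gradient operator described later in this document:
let dtanh = #gradient(of: tanh)
dtanh(0.5) // ≈ 0.7864, i.e. 1 - tanh(0.5)^2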
@@ -442,7 +442,7 @@ the vector-Jacobian products.
which internally calls `f_can_grad` using a default seed `1` and throws away
the first result (the first result would be used if `#valueAndGradient(of:)`
were the differential operator).
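As a sketch of that wrapping (illustrative only; `f`, `f_grad`, and the body of `f_can_grad` are hypothetical stand-ins for compiler-generated code):

```swift
// For a concrete f(x) = x * x, the canonical gradient function returns the
// original result followed by the vector-Jacobian product: (x², 2x · seed).
func f(_ x: Float) -> Float {
  return x * x
}
func f_can_grad(_ x: Float, _ seed: Float) -> (Float, Float) {
  return (x * x, 2 * x * seed)
}

// What a #gradient(of: f)-style wrapper reduces to: call f_can_grad with the
// default seed 1 and throw away the first (original) result.
func f_grad(_ x: Float) -> Float {
  let (originalResult, dx) = f_can_grad(x, 1)
  _ = originalResult // kept instead if #valueAndGradient(of:) were the operator
  return dx
}
```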
More than one function exists to wrap the canonical gradient function
`f_can_grad`, because we'll support a variety of AD configurations, e.g.
`#gradient(of:)` and `#valueAndGradient(of:)`. We expect the finalized gradient
@@ -514,8 +514,8 @@ confusion](https://arxiv.org/abs/1211.4892) are two common bugs in nested uses
of the differential operator using SCT techniques, and require user attention
to resolve correctly. The application of rank-2 polymorphism in the
[ad](https://hackage.haskell.org/package/ad) package in Haskell defined away
sensitivity confusion, but Swift’s type system does not support that today. In
order to support higher-order differentiation with sound semantics and predictable
behavior in Swift, we need to teach the compiler to carefully emit diagnostics and
reject malformed cases.
