@@ -25,8 +25,7 @@ public protocol Optimizer {
     var learningRate: Scalar { get set }
     /// Updates the specified differentiable variables along the specified
     /// direction.
-    mutating func update(_ variables: inout Model.AllDifferentiableVariables,
-                         along direction: Model.TangentVector)
+    mutating func update(_ variables: inout Model, along direction: Model.TangentVector)
 }
 
 fileprivate extension Tensor where Scalar: Numeric {
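
The protocol change means an optimizer now mutates the model value itself rather than its `AllDifferentiableVariables` projection. A minimal sketch of the intended call site, where `MyModel`, `batches`, and `loss(for:on:)` are placeholders and the `Adam` initializer labels are assumed:

```swift
var model = MyModel()                                  // any `Layer` meeting Adam's constraint
let optimizer = Adam<MyModel>(learningRate: 1e-3)      // initializer labels assumed

for batch in batches {
    // Differentiate the loss with respect to the model to get a `MyModel.TangentVector`.
    let grads = gradient(at: model) { model -> Tensor<Float> in
        loss(for: model, on: batch)                    // placeholder loss function
    }
    // New entry point: pass the model itself, not `model.allDifferentiableVariables`.
    optimizer.update(&model, along: grads)
}
```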
@@ -35,14 +34,13 @@ fileprivate extension Tensor where Scalar: Numeric {
     }
 }
 
-// MARK: - Key-path based optimizers
-
 /// Adam optimizer.
 ///
 /// Reference: ["Adam - A Method for Stochastic Optimization"](
 /// https://arxiv.org/abs/1412.6980v8)
 public class Adam<Model: Layer>: Optimizer
     where Model.AllDifferentiableVariables == Model.TangentVector {
+    public typealias Model = Model
     /// The learning rate.
     public var learningRate: Float
     /// A coefficient used to calculate the first and second moments of
@@ -96,7 +94,7 @@ public class Adam<Model: Layer>: Optimizer
         }
     }
 
-
+    // TODO: Deprecate this when `Differentiable.AllDifferentiableVariables` is removed.
     public func update(_ model: inout Model.AllDifferentiableVariables,
                        along direction: Model.AllDifferentiableVariables) {
         step += 1
@@ -127,6 +125,11 @@ public class Adam<Model: Layer>: Optimizer
                 sqrt(secondMoments[keyPath: kp]) + Double(epsilon)
         }
     }
+
+    public func update(_ model: inout Model,
+                       along direction: Model.TangentVector) {
+        update(&model.allDifferentiableVariables, along: direction)
+    }
 }
 
 /// RMSProp optimizer.
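
Each optimizer keeps its key-path based `update(_:along:)` (now marked with a deprecation TODO) and gains a thin overload that forwards through `allDifferentiableVariables`, so older call sites keep compiling. A sketch reusing the placeholders above:

```swift
// Legacy call site: updates the variable projection in place.
optimizer.update(&model.allDifferentiableVariables, along: grads)

// New call site: equivalent, forwarding through `allDifferentiableVariables` internally.
optimizer.update(&model, along: grads)
```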
@@ -139,6 +142,7 @@ public class Adam<Model: Layer>: Optimizer
 /// http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
 public class RMSProp<Model: Layer>: Optimizer
     where Model.AllDifferentiableVariables == Model.TangentVector {
+    public typealias Model = Model
     /// The learning rate.
     public var learningRate: Float
     // TODO: Document `rho`. Keras doesn't document `rho`.
@@ -176,7 +180,7 @@ public class RMSProp<Model: Layer>: Optimizer
         }
     }
 
-
+    // TODO: Deprecate this when `Differentiable.AllDifferentiableVariables` is removed.
     public func update(_ model: inout Model.AllDifferentiableVariables,
                        along direction: Model.TangentVector) {
         step += 1
@@ -195,14 +199,21 @@ public class RMSProp<Model: Layer>: Optimizer
                 (sqrt(alpha[keyPath: kp]) + Double(epsilon))
         }
     }
+
+    public func update(_ model: inout Model,
+                       along direction: Model.TangentVector) {
+        update(&model.allDifferentiableVariables, along: direction)
+    }
 }
 
 /// Stochastic gradient descent (SGD) optimizer.
 ///
 /// An optimizer that implements stochastic gradient descent, with support for momentum, learning
 /// rate decay, and Nesterov momentum.
-public class SGD<Model: Layer>: Optimizer
-    where Model.AllDifferentiableVariables == Model.TangentVector {
+public class SGD<Model: Differentiable>: Optimizer
+    where Model.TangentVector: VectorProtocol & ElementaryFunctions,
+          Model.TangentVector.VectorSpaceScalar == Float {
+    public typealias Model = Model
     /// The learning rate.
     public var learningRate: Float
     /// The momentum factor. It accelerates stochastic gradient descent in the relevant direction
@@ -212,8 +223,8 @@ public class SGD<Model: Layer>: Optimizer
     public var decay: Float
     /// Use Nesterov momentum if true.
     public var nesterov: Bool
-    /// The velocity state of the model
-    public var velocity: Model.AllDifferentiableVariables
+    /// The velocity state of the model.
+    public var velocity: Model.TangentVector = .zero
     /// The set of steps taken.
     public var step: Int = 0
 
@@ -232,53 +243,38 @@ public class SGD<Model: Layer>: Optimizer
         self.momentum = momentum
         self.decay = decay
         self.nesterov = nesterov
-        velocity = model.allDifferentiableVariables
-        for kp in velocity.recursivelyAllWritableKeyPaths(to: Tensor<Float>.self) {
-            velocity[keyPath: kp].resetToZero()
-        }
-        for kp in velocity.recursivelyAllWritableKeyPaths(to: Tensor<Double>.self) {
-            velocity[keyPath: kp].resetToZero()
-        }
     }
 
+    // TODO: Deprecate this when `Differentiable.AllDifferentiableVariables` is removed.
     public func update(_ model: inout Model.AllDifferentiableVariables,
                        along direction: Model.TangentVector) {
         step += 1
         let learningRate = self.learningRate * 1 / (1 + decay * Float(step))
-        for kp in model.recursivelyAllWritableKeyPaths(to: Tensor<Float>.self) {
-            velocity[keyPath: kp] =
-                momentum * velocity[keyPath: kp] - learningRate * direction[keyPath: kp]
-            if nesterov {
-                model[keyPath: kp] +=
-                    momentum * velocity[keyPath: kp] - learningRate * direction[keyPath: kp]
-            } else {
-                model[keyPath: kp] += velocity[keyPath: kp]
-            }
-        }
-        for kp in model.recursivelyAllWritableKeyPaths(to: Tensor<Double>.self) {
-            velocity[keyPath: kp] =
-                Double(momentum) * velocity[keyPath: kp] -
-                Double(learningRate) * direction[keyPath: kp]
-            if nesterov {
-                model[keyPath: kp] +=
-                    Double(momentum) * velocity[keyPath: kp] - Double(learningRate) *
-                    direction[keyPath: kp]
-            } else {
-                model[keyPath: kp] += velocity[keyPath: kp]
-            }
+        velocity = momentum * velocity - direction * learningRate
+        if nesterov {
+            model.move(along: momentum * velocity - direction * learningRate)
+        } else {
+            model.move(along: velocity)
         }
     }
+
+    public func update(_ model: inout Model,
+                       along direction: Model.TangentVector) {
+        update(&model.allDifferentiableVariables, along: direction)
+    }
 }
 
 // MARK: - Manifold optimizers
 
 /// A Riemann manifold stochastic gradient descent (SGD) optimizer.
-public class RiemannSGD<Model: Layer, Scalar: FloatingPoint>: Optimizer
-    where Model.TangentVector: VectorProtocol, Model.TangentVector.VectorSpaceScalar == Scalar {
+public class RiemannSGD<Model: Differentiable>: Optimizer
+    where Model.TangentVector: VectorProtocol,
+          Model.TangentVector.VectorSpaceScalar: FloatingPoint {
+    public typealias Scalar = Model.TangentVector.VectorSpaceScalar
     /// The learning rate.
-    public var learningRate: Scalar
+    public var learningRate: Model.TangentVector.VectorSpaceScalar
 
-    public init(learningRate: Scalar) {
+    public init(learningRate: Model.TangentVector.VectorSpaceScalar) {
         self.learningRate = learningRate
     }
 
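
The rewritten SGD body applies the classic momentum rule directly to the whole tangent vector via `move(along:)` instead of looping over `Tensor<Float>`/`Tensor<Double>` key paths. A scalar walk-through of one step (illustrative numbers only):

```swift
var parameter: Float = 1.0
var velocity: Float = 0.0
let gradient: Float = 0.5
let momentum: Float = 0.9
let decay: Float = 0.0
let step = 1

// Decay schedule as in the diff: learningRate * 1 / (1 + decay * step).
let learningRate: Float = 0.1 * 1 / (1 + decay * Float(step))        // 0.1

velocity = momentum * velocity - gradient * learningRate              // -0.05
parameter += velocity                                                  // 0.95 (plain momentum)
// Nesterov would instead apply a look-ahead step from the same starting value of 1.0:
// parameter += momentum * velocity - gradient * learningRate         // 1.0 - 0.095 = 0.905
```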
@@ -305,6 +301,7 @@ public class RiemannSGD<Model: Layer, Scalar: FloatingPoint>: Optimizer
 ///
 public class AdaGrad<Model: Layer>: Optimizer
     where Model.AllDifferentiableVariables == Model.TangentVector {
+    public typealias Model = Model
     /// The learning rate.
     public var learningRate: Float
     /// The smoothing factor (ρ). Typical values are `0.5`, `0.9`, and `0.99`, for smoothing over 2,
@@ -337,6 +334,7 @@ public class AdaGrad<Model: Layer>: Optimizer
         }
     }
 
+    // TODO: Deprecate this when `Differentiable.AllDifferentiableVariables` is removed.
     public func update(_ model: inout Model.AllDifferentiableVariables,
                        along direction: Model.TangentVector) {
         for kp in model.recursivelyAllWritableKeyPaths(to: Tensor<Float>.self) {
@@ -351,4 +349,9 @@ public class AdaGrad<Model: Layer>: Optimizer
                 (sqrt(alpha[keyPath: kp] + Double(epsilon)))
         }
     }
+
+    public func update(_ model: inout Model,
+                       along direction: Model.TangentVector) {
+        update(&model.allDifferentiableVariables, along: direction)
+    }
 }
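
A new conformance only has to supply `learningRate` and the model-taking `update(_:along:)`. Below is a minimal, hypothetical gradient-descent optimizer written against the revised protocol; it copies the generic constraints the new `SGD` uses and otherwise relies only on members visible in this diff (`move(along:)` and tangent-vector arithmetic), so treat it as a sketch rather than library API:

```swift
public class PlainGradientDescent<Model: Differentiable>: Optimizer
    where Model.TangentVector: VectorProtocol & ElementaryFunctions,
          Model.TangentVector.VectorSpaceScalar == Float {
    public typealias Model = Model
    /// The learning rate.
    public var learningRate: Float

    public init(learningRate: Float = 0.01) {
        self.learningRate = learningRate
    }

    public func update(_ model: inout Model, along direction: Model.TangentVector) {
        // θ ← θ − η·g, using the same tangent-vector arithmetic the rewritten SGD uses.
        model.move(along: direction * -learningRate)
    }
}
```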