@@ -376,3 +376,92 @@ public class AMSGrad<Model: Differentiable & KeyPathIterable>: Optimizer
        model.move(along: -stepSize * firstMoments ./ denominator)
    }
}
+
+/// RAdam optimizer.
+///
+/// Rectified Adam, a variant of Adam that introduces a term to rectify the variance
+/// of the adaptive learning rate.
+///
+/// Reference: ["On the Variance of the Adaptive Learning Rate and Beyond"](
+/// https://arxiv.org/pdf/1908.03265.pdf)
+public class RAdam<Model: Differentiable>: Optimizer
+    where Model.TangentVector: VectorProtocol & PointwiseMultiplicative &
+                               ElementaryFunctions & KeyPathIterable,
+          Model.TangentVector.VectorSpaceScalar == Float {
+    public typealias Model = Model
+    /// The learning rate.
+    public var learningRate: Float
+    /// A coefficient used to calculate the first moments of the gradients.
+    public var beta1: Float
+    /// A coefficient used to calculate the second moments of the gradients.
+    public var beta2: Float
+    /// A small scalar added to the denominator to improve numerical stability.
+    public var epsilon: Float
+    /// The learning rate decay.
+    public var decay: Float
+    /// The current step.
+    public var step: Int = 0
+    /// The first moments of the weights.
+    public var firstMoments: Model.TangentVector = .zero
+    /// The second moments of the weights.
+    public var secondMoments: Model.TangentVector = .zero
+    /// The first moments used by the update step (the bias correction is folded into
+    /// the step size).
+    public var firstMoments_h: Model.TangentVector = .zero
+    /// The element-wise square root of the second moments, used as the denominator of
+    /// the rectified update.
+    public var secondMoments_h: Model.TangentVector = .zero
+
+    public init(
+        for model: __shared Model,
+        learningRate: Float = 1e-3,
+        beta1: Float = 0.9,
+        beta2: Float = 0.999,
+        epsilon: Float = 1e-8,
+        decay: Float = 0
+    ) {
+        precondition(learningRate >= 0, "Learning rate must be non-negative")
+        precondition(0 <= beta1 && beta1 <= 1, "Beta parameter must be between 0 and 1")
+        precondition(0 <= beta2 && beta2 <= 1, "Beta parameter must be between 0 and 1")
+        precondition(decay >= 0, "Learning rate decay must be non-negative")
+
+        self.learningRate = learningRate
+        self.beta1 = beta1
+        self.beta2 = beta2
+        self.epsilon = epsilon
+        self.decay = decay
+    }
+
+    public func update(_ model: inout Model, along direction: Model.TangentVector) {
+        step += 1
+        let step = Float(self.step)
+        let beta1Power = pow(beta1, step)
+        let beta2Power = pow(beta2, step)
+        // Apply the learning rate decay, and fold the first-moment bias correction
+        // `1 / (1 - beta1^t)` into the step size.
+        let stepSize = self.learningRate / ((1 + decay * step) * (1 - beta1Power))
+        secondMoments = beta2 * secondMoments + direction .* direction * (1 - beta2)
+        firstMoments = beta1 * firstMoments + direction * (1 - beta1)
+
+        // Compute the maximum length of the approximated SMA and its value at the
+        // current step:
+        //   N_sma_inf = 2 / (1 - beta2) - 1
+        //   N_sma_t = N_sma_inf - 2 * t * beta2^t / (1 - beta2^t)
+        let N_sma_inf = 2 / (1 - beta2) - 1
+        let N_sma_t = N_sma_inf - 2 * step * beta2Power / (1 - beta2Power)
+        firstMoments_h = firstMoments
+
+        if N_sma_t > 4 {
+            // The variance is tractable: compute the bias-corrected second moments
+            // and the rectification term `r`, then take a rectified adaptive step.
+            secondMoments_h = Model.TangentVector.sqrt(secondMoments) + epsilon
+            let r = sqrt(
+                (N_sma_t - 4) * (N_sma_t - 2) * N_sma_inf /
+                ((N_sma_inf - 4) * (N_sma_inf - 2) * N_sma_t))
+            model.move(along: -stepSize * sqrt(1 - beta2Power) * r * firstMoments_h ./ secondMoments_h)
+        } else {
+            // The rectification term is undefined: update with un-adapted momentum.
+            model.move(along: -stepSize * firstMoments_h)
+        }
+    }
+}
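
For reference, a minimal usage sketch of the new optimizer (not part of the patch), assuming the `Dense` layer, `meanSquaredError` loss, and `gradient(at:)` utilities from the same library; the model shape and training data are placeholders:

import TensorFlow

// A toy regression model; any `Differentiable` model whose tangent vector
// satisfies the `RAdam` constraints works the same way.
var model = Dense<Float>(inputSize: 4, outputSize: 1)
let optimizer = RAdam(for: model, learningRate: 1e-3)

let x = Tensor<Float>(randomNormal: [8, 4])  // placeholder inputs
let y = Tensor<Float>(zeros: [8, 1])         // placeholder targets

for _ in 0..<100 {
    // Differentiate the loss with respect to the model and take one RAdam step.
    let grad = gradient(at: model) { model -> Tensor<Float> in
        meanSquaredError(predicted: model(x), expected: y)
    }
    optimizer.update(&model, along: grad)
}

With the default beta2 = 0.999, N_sma_t stays below 4 for the first four steps and crosses it at step 5, so the earliest updates fall through to the un-adapted momentum branch; this warmup is exactly what the rectification term is designed to provide.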