@@ -376,3 +376,79 @@ public class AMSGrad<Model: Differentiable & KeyPathIterable>: Optimizer
        model.move(along: -stepSize * firstMoments ./ denominator)
    }
}
+
+/// RAdam optimizer.
+///
+/// Rectified Adam, a variant of Adam that introduces a term to rectify the variance of the
+/// adaptive learning rate.
+///
+/// Reference: ["On the Variance of the Adaptive Learning Rate and Beyond"](
+/// https://arxiv.org/pdf/1908.03265.pdf)
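+///
+/// A doc-level summary of the update in the paper's notation (not part of the API): with
+/// rho_inf = 2 / (1 - beta2) - 1 and rho_t = rho_inf - 2 * t * beta2^t / (1 - beta2^t), the
+/// adaptive step is scaled by the rectification term
+/// r_t = sqrt(((rho_t - 4) * (rho_t - 2) * rho_inf) / ((rho_inf - 4) * (rho_inf - 2) * rho_t))
+/// once rho_t is large enough; otherwise the update falls back to un-adapted momentum SGD.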
+public class RAdam<Model: Differentiable>: Optimizer
+    where Model.TangentVector: VectorProtocol & PointwiseMultiplicative &
+          ElementaryFunctions & KeyPathIterable,
+          Model.TangentVector.VectorSpaceScalar == Float {
+    public typealias Model = Model
+    /// The learning rate.
+    public var learningRate: Float
+    /// A coefficient used to calculate the first moments of the gradients.
+    public var beta1: Float
+    /// A coefficient used to calculate the second moments of the gradients.
+    public var beta2: Float
+    /// A small scalar added to the denominator to improve numerical stability.
+    public var epsilon: Float
+    /// The learning rate decay.
+    public var decay: Float
+    /// The current step.
+    public var step: Int = 0
+    /// The first moments of the weights.
+    public var firstMoments: Model.TangentVector = .zero
+    /// The second moments of the weights.
+    public var secondMoments: Model.TangentVector = .zero
+
+    public init(
+        for model: __shared Model,
+        learningRate: Float = 1e-3,
+        beta1: Float = 0.9,
+        beta2: Float = 0.999,
+        epsilon: Float = 1e-8,
+        decay: Float = 0
+    ) {
+        precondition(learningRate >= 0, "Learning rate must be non-negative")
+        precondition(0 <= beta1 && beta1 <= 1, "Beta parameter must be between 0 and 1")
+        precondition(0 <= beta2 && beta2 <= 1, "Beta parameter must be between 0 and 1")
+        precondition(decay >= 0, "Learning rate decay must be non-negative")
+
+        self.learningRate = learningRate
+        self.beta1 = beta1
+        self.beta2 = beta2
+        self.epsilon = epsilon
+        self.decay = decay
+    }
+
+    public func update(_ model: inout Model, along direction: Model.TangentVector) {
+        step += 1
+        let step = Float(self.step)
+        let beta1Power = pow(beta1, step)
+        let beta2Power = pow(beta2, step)
+        // Apply the learning rate decay before taking the step.
+        let learningRate = self.learningRate / (1 + decay * step)
+        secondMoments = beta2 * secondMoments + direction .* direction * (1 - beta2)
+        firstMoments = beta1 * firstMoments + direction * (1 - beta1)
+        // Compute the maximum and bias-corrected current lengths of the approximated simple
+        // moving average (SMA).
+        let N_sma_inf = 2 / (1 - beta2) - 1
+        let N_sma_t = N_sma_inf - 2 * step * beta2Power / (1 - beta2Power)
+
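+        // N_sma_inf and N_sma_t correspond to rho_inf and rho_t in the paper: the maximum and
+        // current lengths of the approximated SMA. The rectified branch below activates only
+        // once N_sma_t is large enough for the variance of the adaptive term to be tractable.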
+        if N_sma_t > 5 {
+            // Take a rectified adaptive step: bias-correct both moments and scale the update
+            // by the rectification term.
+            let secondMoments_h = Model.TangentVector.sqrt(secondMoments) + epsilon
+            let stepSize = learningRate * sqrt(
+                (N_sma_t - 4) * (N_sma_t - 2) * N_sma_inf /
+                ((N_sma_inf - 4) * (N_sma_inf - 2) * N_sma_t)) / (1 - beta1Power)
+            model.move(along: -stepSize * sqrt(1 - beta2Power) * firstMoments ./ secondMoments_h)
+        } else {
+            // The variance of the adaptive term is not yet tractable; take an un-adapted step
+            // with bias-corrected momentum only.
+            let stepSize = learningRate / (1 - beta1Power)
+            model.move(along: -stepSize * firstMoments)
+        }
+    }
+}
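
A minimal usage sketch (assuming the surrounding swift-apis `Layer`, `Dense`, and loss APIs; the `MLP` model, shapes, and data here are hypothetical):

import TensorFlow

// Hypothetical two-layer regression model; any `Layer`-conforming type whose
// `TangentVector` satisfies RAdam's constraints works the same way.
struct MLP: Layer {
    var dense1 = Dense<Float>(inputSize: 4, outputSize: 8, activation: relu)
    var dense2 = Dense<Float>(inputSize: 8, outputSize: 1)

    @differentiable
    func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
        return input.sequenced(through: dense1, dense2)
    }
}

var model = MLP()
let optimizer = RAdam(for: model, learningRate: 1e-3)
let x = Tensor<Float>(randomNormal: [32, 4])
let y = Tensor<Float>(randomNormal: [32, 1])
for _ in 0..<100 {
    let grad = gradient(at: model) { model -> Tensor<Float> in
        meanSquaredError(predicted: model(x), expected: y)
    }
    optimizer.update(&model, along: grad)
}

The optimizer accumulates its moment estimates and step count across calls, so a single instance should be reused for the whole training loop.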