@@ -377,10 +377,10 @@ public class AMSGrad<Model: Differentiable & KeyPathIterable>: Optimizer
377
377
}
378
378
}
379
379
380
- /// RAdam Optimizer
380
+ /// RAdam optimizer.
381
381
///
382
- /// Recitified Adam optimizer , a variant of Adam that introduces a term to rectify
383
- /// variance of adaptive learning rate
382
+ /// Rectified Adam, a variant of Adam that introduces a term to rectify the adaptive learning rate
383
+ /// variance.
384
384
///
385
385
/// Reference: ["On the Variance of the Adaptive Learning Rate and Beyond"]
386
386
/// https://arxiv.org/pdf/1908.03265.pdf
@@ -431,24 +431,24 @@ public class RAdam<Model: Differentiable>: Optimizer
431
431
let step = Float ( self . step)
432
432
let beta1Power = pow ( beta1, step)
433
433
let beta2Power = pow ( beta2, step)
434
- // let stepSize = self.learningRate * step / (1 - beta1Power)
435
434
secondMoments = beta2 * secondMoments + direction .* direction * ( 1 - beta2)
436
435
firstMoments = beta1 * firstMoments + direction * ( 1 - beta1)
437
- // Compute maximum length SMA, bias-corrected moving average and approximate length
438
- // SMA
436
+ // Compute maximum length SMA, bias-corrected moving average and approximate length.
439
437
let N_sma_inf = 2 / ( 1 - beta2) - 1
440
- let N_sma_t = N_sma_inf - 2 * step* beta2Power / ( 1 - beta2Power)
438
+ let N_sma_t = N_sma_inf - 2 * step * beta2Power / ( 1 - beta2Power)
441
439
442
440
if N_sma_t > 5 {
443
- // Compute Bias corrected second moments, rectification and adapted momentum
441
+ // Compute bias- corrected second moments, rectification and adapted momentum.
444
442
let secondMoments_h = Model . TangentVector. sqrt ( secondMoments) + epsilon
445
- let stepSize = sqrt ( ( N_sma_t- 4 ) * ( N_sma_t- 2 ) * N_sma_inf/ ( ( N_sma_inf- 4 ) * ( N_sma_inf- 2 ) * ( N_sma_t) ) )
446
- model. move ( along: - stepSize*sqrt( 1 - beta2Power) * firstMoments./ secondMoments_h)
447
- }
448
- else {
449
- // Update with un-adapted momentum
443
+ let stepSize = sqrt (
444
+ ( N_sma_t - 4 ) * ( N_sma_t - 2 ) * N_sma_inf / (
445
+ ( N_sma_inf - 4 ) * ( N_sma_inf - 2 ) * ( N_sma_t)
446
+ ) )
447
+ model. move ( along: - stepSize * sqrt( 1 - beta2Power) * firstMoments ./ secondMoments_h)
448
+ } else {
449
+ // Update with un-adapted momentum.
450
450
let stepSize = self . learningRate * step / ( 1 - beta1Power)
451
- model. move ( along: - stepSize* firstMoments)
451
+ model. move ( along: - stepSize * firstMoments)
452
452
}
453
453
}
454
454
}
0 commit comments