This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 5bde6f8

Add: Rectified Adam optimizer
1 parent b8ae57e commit 5bde6f8

File tree

1 file changed: +82 -0 lines changed


Sources/TensorFlow/Optimizers/MomentumBased.swift

Lines changed: 82 additions & 0 deletions
@@ -376,3 +376,85 @@ public class AMSGrad<Model: Differentiable & KeyPathIterable>: Optimizer
         model.move(along: -stepSize * firstMoments ./ denominator)
     }
 }
+
+/// RAdam optimizer.
+///
+/// Rectified Adam, a variant of Adam that introduces a term to rectify the
+/// variance of the adaptive learning rate.
+///
+/// Reference: ["On the Variance of the Adaptive Learning Rate and Beyond"](
+/// https://arxiv.org/pdf/1908.03265.pdf)
+public class RAdam<Model: Differentiable>: Optimizer
+    where Model.TangentVector: VectorProtocol & PointwiseMultiplicative &
+                               ElementaryFunctions & KeyPathIterable,
+          Model.TangentVector.VectorSpaceScalar == Float {
+    public typealias Model = Model
+    /// The learning rate.
+    public var learningRate: Float
+    /// A coefficient used to calculate the first moments of the gradients.
+    public var beta1: Float
+    /// A coefficient used to calculate the second moments of the gradients.
+    public var beta2: Float
+    /// A small scalar added to the denominator to improve numerical stability.
+    public var epsilon: Float
+    /// The learning rate decay.
+    public var decay: Float
+    /// The current step.
+    public var step: Int = 0
+    /// The first moments of the weights.
+    public var firstMoments: Model.TangentVector = .zero
+    /// The second moments of the weights.
+    public var secondMoments: Model.TangentVector = .zero
+    /// The first moments of the weights used to form the update.
+    public var firstMoments_h: Model.TangentVector = .zero
+    /// The square root of the second moments, used to form the rectified update.
+    public var secondMoments_h: Model.TangentVector = .zero
+
+    public init(
+        for model: __shared Model,
+        learningRate: Float = 1e-3,
+        beta1: Float = 0.9,
+        beta2: Float = 0.999,
+        epsilon: Float = 1e-8,
+        decay: Float = 0
+    ) {
+        precondition(learningRate >= 0, "Learning rate must be non-negative")
+        precondition(0 <= beta1 && beta1 <= 1, "Beta parameter must be between 0 and 1")
+        precondition(0 <= beta2 && beta2 <= 1, "Beta parameter must be between 0 and 1")
+        precondition(decay >= 0, "Learning rate decay must be non-negative")
+
+        self.learningRate = learningRate
+        self.beta1 = beta1
+        self.beta2 = beta2
+        self.epsilon = epsilon
+        self.decay = decay
+    }
+
+    public func update(_ model: inout Model, along direction: Model.TangentVector) {
+        step += 1
+        let step = Float(self.step)
+        let beta1Power = pow(beta1, step)
+        let beta2Power = pow(beta2, step)
+        let stepSize = self.learningRate * step / (1 - beta1Power)
+        secondMoments = beta2 * secondMoments + direction .* direction * (1 - beta2)
+        firstMoments = beta1 * firstMoments + direction * (1 - beta1)
+
+        // Compute the maximum length of the SMA, the bias-corrected moving average,
+        // and the approximate length of the SMA.
+        let N_sma_inf = 2 / (1 - beta2) - 1
+        let N_sma_t = N_sma_inf - 2 * step * beta2Power / (1 - beta2Power)
+        firstMoments_h = firstMoments
+
+        if N_sma_t > 4 {
+            // Compute the bias-corrected second moments, the rectification term, and
+            // the variance-adapted update.
+            secondMoments_h = Model.TangentVector.sqrt(secondMoments)
+            let r = sqrt((N_sma_t - 4) * (N_sma_t - 2) * N_sma_inf / ((N_sma_inf - 4) * (N_sma_inf - 2) * N_sma_t))
+            model.move(along: -stepSize * sqrt(1 - beta2Power) * firstMoments_h * r ./ secondMoments_h)
+        }
+        else {
+            // Update with un-adapted momentum.
+            model.move(along: -stepSize * firstMoments_h)
+        }
+    }
+}
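
For reference, the quantities `N_sma_inf`, `N_sma_t`, and `r` computed in `update(_:along:)` correspond to the rectification terms from the referenced paper; this is a transcription of those equations, not part of the commit:

```latex
\rho_\infty = \frac{2}{1 - \beta_2} - 1,
\qquad
\rho_t = \rho_\infty - \frac{2\, t\, \beta_2^{t}}{1 - \beta_2^{t}},
\qquad
r_t = \sqrt{\frac{(\rho_t - 4)(\rho_t - 2)\,\rho_\infty}
                 {(\rho_\infty - 4)(\rho_\infty - 2)\,\rho_t}}
```

The rectified step scaled by r_t is taken only when the approximate SMA length is above 4 (the paper's tractability condition), which is exactly what the `if N_sma_t > 4` branch checks; otherwise the optimizer falls back to an un-adapted momentum step.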
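A minimal usage sketch of the new optimizer, not part of this commit: it assumes the Swift for TensorFlow `Layer`, `Dense`, and `gradient(at:in:)` APIs of this era, and the `MLP` model and random training tensors are hypothetical.

```swift
import TensorFlow

// A small hypothetical model; any Layer whose synthesized TangentVector
// satisfies RAdam's generic constraints works the same way.
struct MLP: Layer {
    var dense1 = Dense<Float>(inputSize: 4, outputSize: 8, activation: relu)
    var dense2 = Dense<Float>(inputSize: 8, outputSize: 1)

    @differentiable
    func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
        input.sequenced(through: dense1, dense2)
    }
}

var model = MLP()
let optimizer = RAdam(for: model, learningRate: 1e-3)

// Hypothetical training data.
let x = Tensor<Float>(randomNormal: [32, 4])
let y = Tensor<Float>(randomNormal: [32, 1])

for _ in 0..<100 {
    // Differentiate the loss with respect to the model and take one RAdam step.
    let grads = gradient(at: model) { model -> Tensor<Float> in
        meanSquaredError(predicted: model(x), expected: y)
    }
    optimizer.update(&model, along: grads)
}
```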
