
Commit efb4596

vballoli authored and sgugger committed
Add: Rectified Adam optimizer (#564)
* Add: Rectified Adam optimizer
* Add RAdam test
* Changes to RAdam based on original implementation: removed redundant variables, updated the RAdam test, and added epsilon to secondMoments_h
* Fix RAdam test values
* NFC: style changes.
1 parent: c2736d4

2 files changed: +80 -3 lines

Sources/TensorFlow/Optimizers/MomentumBased.swift

Lines changed: 76 additions & 0 deletions
@@ -376,3 +376,79 @@ public class AMSGrad<Model: Differentiable & KeyPathIterable>: Optimizer
         model.move(along: -stepSize * firstMoments ./ denominator)
     }
 }
+
+/// RAdam optimizer.
+///
+/// Rectified Adam, a variant of Adam that introduces a term to rectify the adaptive learning rate
+/// variance.
+///
+/// Reference: ["On the Variance of the Adaptive Learning Rate and Beyond"](
+/// https://arxiv.org/pdf/1908.03265.pdf)
+public class RAdam<Model: Differentiable>: Optimizer
+    where Model.TangentVector: VectorProtocol & PointwiseMultiplicative &
+                               ElementaryFunctions & KeyPathIterable,
+          Model.TangentVector.VectorSpaceScalar == Float {
+    public typealias Model = Model
+    /// The learning rate.
+    public var learningRate: Float
+    /// A coefficient used to calculate the first moments of the gradients.
+    public var beta1: Float
+    /// A coefficient used to calculate the second moments of the gradients.
+    public var beta2: Float
+    /// A small scalar added to the denominator to improve numerical stability.
+    public var epsilon: Float
+    /// The learning rate decay.
+    public var decay: Float
+    /// The current step.
+    public var step: Int = 0
+    /// The first moments of the weights.
+    public var firstMoments: Model.TangentVector = .zero
+    /// The second moments of the weights.
+    public var secondMoments: Model.TangentVector = .zero
+
+    public init(
+        for model: __shared Model,
+        learningRate: Float = 1e-3,
+        beta1: Float = 0.9,
+        beta2: Float = 0.999,
+        epsilon: Float = 1e-8,
+        decay: Float = 0
+    ) {
+        precondition(learningRate >= 0, "Learning rate must be non-negative")
+        precondition(0 <= beta1 && beta1 <= 1, "Beta parameter must be between 0 and 1")
+        precondition(0 <= beta2 && beta2 <= 1, "Beta parameter must be between 0 and 1")
+        precondition(decay >= 0, "Learning rate decay must be non-negative")
+
+        self.learningRate = learningRate
+        self.beta1 = beta1
+        self.beta2 = beta2
+        self.epsilon = epsilon
+        self.decay = decay
+    }
+
+    public func update(_ model: inout Model, along direction: Model.TangentVector) {
+        step += 1
+        let step = Float(self.step)
+        let beta1Power = pow(beta1, step)
+        let beta2Power = pow(beta2, step)
+        secondMoments = beta2 * secondMoments + direction .* direction * (1 - beta2)
+        firstMoments = beta1 * firstMoments + direction * (1 - beta1)
+        // Compute maximum length SMA, bias-corrected moving average and approximate length.
+        let N_sma_inf = 2 / (1 - beta2) - 1
+        let N_sma_t = N_sma_inf - 2 * step * beta2Power / (1 - beta2Power)
+
+        if N_sma_t > 5 {
+            // Compute bias-corrected second moments, rectification and adapted momentum.
+            let secondMoments_h = Model.TangentVector.sqrt(secondMoments) + epsilon
+            let stepSize = sqrt(
+                (N_sma_t - 4) * (N_sma_t - 2) * N_sma_inf / (
+                (N_sma_inf - 4) * (N_sma_inf - 2) * (N_sma_t)
+                ))
+            model.move(along: -stepSize * sqrt(1 - beta2Power) * firstMoments ./ secondMoments_h)
+        } else {
+            // Update with un-adapted momentum.
+            let stepSize = self.learningRate * step / (1 - beta1Power)
+            model.move(along: -stepSize * firstMoments)
+        }
+    }
+}
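For readers mapping the code above to the cited paper: N_sma_inf and N_sma_t correspond to rho_infinity and rho_t of Liu et al., the maximum and per-step length of the approximated simple moving average, and the stepSize computed in the rectified branch is the variance rectification term r_t:

    \rho_\infty = \frac{2}{1 - \beta_2} - 1, \qquad
    \rho_t = \rho_\infty - \frac{2 t \beta_2^t}{1 - \beta_2^t}, \qquad
    r_t = \sqrt{\frac{(\rho_t - 4)(\rho_t - 2)\,\rho_\infty}{(\rho_\infty - 4)(\rho_\infty - 2)\,\rho_t}}

When the variance is not yet tractable, gated here as N_sma_t <= 5 (the paper rectifies whenever rho_t > 4), the update falls back to the un-rectified momentum step in the else branch.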

Tests/TensorFlowTests/SequentialTests.swift

Lines changed: 4 additions & 3 deletions
@@ -42,6 +42,7 @@ final class SequentialTests: XCTestCase {
         let amsgrad = AMSGrad(for: model, learningRate: 0.02)
         let adagrad = AdaGrad(for: model, learningRate: 0.02)
         let adadelta = AdaDelta(for: model, learningRate: 0.02)
+        let radam = RAdam(for: model, learningRate: 0.02)
         let x: Tensor<Float> = [[0, 0], [0, 1], [1, 0], [1, 1]]
         let y: Tensor<Float> = [0, 1, 1, 0]
         Context.local.learningPhase = .training
@@ -58,11 +59,11 @@ final class SequentialTests: XCTestCase {
                 amsgrad.update(&model, along: 𝛁model)
                 adagrad.update(&model, along: 𝛁model)
                 adadelta.update(&model, along: 𝛁model)
+                radam.update(&model, along: 𝛁model)
             }
         }
-        assertEqual(model.inferring(from: [[0, 0], [0, 1], [1, 0], [1, 1]]),
-                    [[0.5115531], [0.5115531], [0.5115531], [0.5115531]],
-                    accuracy: 1e-6)
+        XCTAssertEqual(model.inferring(from: [[0, 0], [0, 1], [1, 0], [1, 1]]),
+                       [[0.5567076], [0.5567076], [0.5567076], [0.5567076]])
     }

     static var allTests = [
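Beyond the regression test, a self-contained sketch of how the new optimizer would be driven from user code is below. It is illustrative, not part of this commit: TinyModel and the XOR data are hypothetical, and it assumes the swift-apis surface of the time (Dense, meanSquaredError(predicted:expected:), gradient(at:in:)).

    import TensorFlow

    // Hypothetical toy model; any Layer whose TangentVector meets RAdam's
    // constraints (VectorProtocol, PointwiseMultiplicative,
    // ElementaryFunctions, KeyPathIterable) can be used.
    struct TinyModel: Layer {
        var hidden = Dense<Float>(inputSize: 2, outputSize: 4, activation: relu)
        var output = Dense<Float>(inputSize: 4, outputSize: 1, activation: sigmoid)

        @differentiable
        func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
            output(hidden(input))
        }
    }

    var model = TinyModel()
    let optimizer = RAdam(for: model, learningRate: 1e-3)

    let x: Tensor<Float> = [[0, 0], [0, 1], [1, 0], [1, 1]]
    let y: Tensor<Float> = [[0], [1], [1], [0]]

    Context.local.learningPhase = .training
    for _ in 0..<1000 {
        // Differentiate the loss with respect to the model parameters,
        // then take one rectified-Adam step along the negative gradient.
        let 𝛁model = gradient(at: model) { model -> Tensor<Float> in
            meanSquaredError(predicted: model(x), expected: y)
        }
        optimizer.update(&model, along: 𝛁model)
    }

Because RAdam's update rule is self-contained in update(_:along:), it drops into the same training loop as the other optimizers in this file; only the construction line changes.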
