This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Add: Rectified Adam optimizer #564

Merged
merged 6 commits on Dec 6, 2019
76 changes: 76 additions & 0 deletions Sources/TensorFlow/Optimizers/MomentumBased.swift
@@ -376,3 +376,79 @@ public class AMSGrad<Model: Differentiable & KeyPathIterable>: Optimizer
        model.move(along: -stepSize * firstMoments ./ denominator)
    }
}

/// RAdam optimizer.
///
/// Rectified Adam, a variant of Adam that introduces a term to rectify the adaptive learning rate
/// variance.
///
/// Reference: ["On the Variance of the Adaptive Learning Rate and Beyond"](
/// https://arxiv.org/pdf/1908.03265.pdf)
public class RAdam<Model: Differentiable>: Optimizer
    where Model.TangentVector: VectorProtocol & PointwiseMultiplicative &
                               ElementaryFunctions & KeyPathIterable,
          Model.TangentVector.VectorSpaceScalar == Float {
    public typealias Model = Model
    /// The learning rate.
    public var learningRate: Float
    /// The exponential decay rate for the first-moment estimates of the gradients.
    public var beta1: Float
    /// The exponential decay rate for the second-moment estimates of the gradients.
    public var beta2: Float
    /// A small scalar added to the denominator to improve numerical stability.
    public var epsilon: Float
    /// The learning rate decay.
    public var decay: Float
    /// The current step.
    public var step: Int = 0
    /// The first moments of the weights.
    public var firstMoments: Model.TangentVector = .zero
    /// The second moments of the weights.
    public var secondMoments: Model.TangentVector = .zero

    public init(
        for model: __shared Model,
        learningRate: Float = 1e-3,
        beta1: Float = 0.9,
        beta2: Float = 0.999,
        epsilon: Float = 1e-8,
        decay: Float = 0
    ) {
        precondition(learningRate >= 0, "Learning rate must be non-negative")
        precondition(0 <= beta1 && beta1 <= 1, "Beta parameter must be between 0 and 1")
        precondition(0 <= beta2 && beta2 <= 1, "Beta parameter must be between 0 and 1")
        precondition(decay >= 0, "Learning rate decay must be non-negative")

        self.learningRate = learningRate
        self.beta1 = beta1
        self.beta2 = beta2
        self.epsilon = epsilon
        self.decay = decay
    }

    public func update(_ model: inout Model, along direction: Model.TangentVector) {
        step += 1
        let step = Float(self.step)
        // Apply learning rate decay, following the convention of the other
        // optimizers in this file.
        let learningRate = self.learningRate / (1 + decay * step)
        let beta1Power = pow(beta1, step)
        let beta2Power = pow(beta2, step)
        secondMoments = beta2 * secondMoments + direction .* direction * (1 - beta2)
        firstMoments = beta1 * firstMoments + direction * (1 - beta1)
        // Compute the maximum length of the approximated SMA (rho_inf in the paper)
        // and its value at the current step (rho_t).
        let N_sma_inf = 2 / (1 - beta2) - 1
        let N_sma_t = N_sma_inf - 2 * step * beta2Power / (1 - beta2Power)

        if N_sma_t > 5 {
            // The variance of the adaptive learning rate is tractable: take a
            // rectified, bias-corrected adaptive step.
            let secondMoments_h = Model.TangentVector.sqrt(secondMoments) + epsilon
            let rectification = sqrt(
                (N_sma_t - 4) * (N_sma_t - 2) * N_sma_inf /
                ((N_sma_inf - 4) * (N_sma_inf - 2) * N_sma_t))
            let stepSize = learningRate * rectification / (1 - beta1Power)
            model.move(along: -stepSize * sqrt(1 - beta2Power) * firstMoments ./ secondMoments_h)
        } else {
            // The variance is intractable in early steps: fall back to a
            // bias-corrected momentum (un-adapted) update.
            let stepSize = learningRate / (1 - beta1Power)
            model.move(along: -stepSize * firstMoments)
        }
    }
}
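
For context, a minimal usage sketch (not part of the diff): it drives RAdam on the XOR inputs that SequentialTests uses. The single-layer model, loss, and iteration count are illustrative assumptions; a lone Dense layer cannot actually solve XOR, the point is only the update loop.

import TensorFlow

// Illustrative only: a single sigmoid Dense layer trained with RAdam.
var model = Dense<Float>(inputSize: 2, outputSize: 1, activation: sigmoid)
let optimizer = RAdam(for: model, learningRate: 0.02)
let x: Tensor<Float> = [[0, 0], [0, 1], [1, 0], [1, 1]]
let y: Tensor<Float> = [[0], [1], [1], [0]]

Context.local.learningPhase = .training
for _ in 0..<1000 {
    // Differentiate the loss with respect to the model, then take one RAdam step.
    let 𝛁model = gradient(at: model) { model -> Tensor<Float> in
        meanSquaredError(predicted: model(x), expected: y)
    }
    optimizer.update(&model, along: 𝛁model)
}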
7 changes: 4 additions & 3 deletions Tests/TensorFlowTests/SequentialTests.swift
@@ -42,6 +42,7 @@ final class SequentialTests: XCTestCase {
let amsgrad = AMSGrad(for: model, learningRate: 0.02)
let adagrad = AdaGrad(for: model, learningRate: 0.02)
let adadelta = AdaDelta(for: model, learningRate: 0.02)
let radam = RAdam(for: model, learningRate: 0.02)
let x: Tensor<Float> = [[0, 0], [0, 1], [1, 0], [1, 1]]
let y: Tensor<Float> = [0, 1, 1, 0]
Context.local.learningPhase = .training
@@ -58,11 +59,11 @@
amsgrad.update(&model, along: 𝛁model)
adagrad.update(&model, along: 𝛁model)
adadelta.update(&model, along: 𝛁model)
radam.update(&model, along: 𝛁model)
}
}
assertEqual(model.inferring(from: [[0, 0], [0, 1], [1, 0], [1, 1]]),
            [[0.5115531], [0.5115531], [0.5115531], [0.5115531]],
            accuracy: 1e-6)
XCTAssertEqual(model.inferring(from: [[0, 0], [0, 1], [1, 0], [1, 1]]),
               [[0.5567076], [0.5567076], [0.5567076], [0.5567076]])
Member
I wonder if you ran this test locally after adding RAdam? The expected results of model.inferring should be updated.

Contributor
+1, do check that the build and tests pass locally.

Contributor Author
The test did pass locally. I will run it again and get back to you.

Contributor Author
The test is failing now; the result is going to NaN. Is there any way I can debug this?

Member
> The test is failing now; the result is going to NaN. Is there any way I can debug this?

I'd recommend writing a standalone unit test for rectified Adam. This will help isolate the issue.

Here's a simple reference test for rectified Adam - would you like to port it to a new file Tests/TensorFlow/OptimizerTests.swift?
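
(The reference test linked in the comment above is not reproduced on the archived page. As a stand-in, here is a minimal sketch of what such a standalone RAdam test could look like; the file name comes from the comment, while the model, data, and the convergence-style assertion are illustrative assumptions. It asserts progress rather than exact values, avoiding pinned floating-point results.)

// OptimizerTests.swift (hypothetical sketch)
import XCTest
@testable import TensorFlow

final class OptimizerTests: XCTestCase {
    func testRAdamConvergence() {
        // A trivial linear-regression problem: y = 2x.
        var model = Dense<Float>(inputSize: 1, outputSize: 1)
        let optimizer = RAdam(for: model, learningRate: 1e-2)
        let x: Tensor<Float> = [[0], [1], [2], [3]]
        let y: Tensor<Float> = [[0], [2], [4], [6]]
        let initialLoss = meanSquaredError(predicted: model(x), expected: y)
        for _ in 0..<200 {
            let 𝛁model = gradient(at: model) { model -> Tensor<Float> in
                meanSquaredError(predicted: model(x), expected: y)
            }
            optimizer.update(&model, along: 𝛁model)
        }
        // RAdam should make clear progress on this trivial problem.
        let finalLoss = meanSquaredError(predicted: model(x), expected: y)
        XCTAssertLessThan(finalLoss.scalarized(), initialLoss.scalarized())
    }
}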

Contributor Author
Yeah, I will be porting it to a new file as you've suggested. I've found the issue; I will make the changes and submit OptimizerTests as a separate PR.

Contributor
Looks like the test is still failing:

XCTAssertEqual failed: ("[[0.5567076],
 [0.5567076],
 [0.5567076],
 [0.5567076]]") is not equal to ("[[0.5053005],
 [0.5053005],
 [0.5053005],
 [0.5053005]]")
}

static var allTests = [