Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit da43ba3

Browse files
committed
Changes to RAdam based on original implementation
Removed redundant variables, updated RAdam test and added epsilon to secondMoments_h
1 parent 3c22d78 commit da43ba3

File tree

2 files changed

+11
-17
lines changed

2 files changed

+11
-17
lines changed

Sources/TensorFlow/Optimizers/MomentumBased.swift

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -406,10 +406,6 @@ public class RAdam<Model: Differentiable>: Optimizer
406406
/// The second moments of the weights.
407407
public var secondMoments: Model.TangentVector = .zero
408408

409-
public var firstMoments_h: Model.TangentVector = .zero
410-
411-
public var secondMoments_h: Model.TangentVector = .zero
412-
413409
public init(
414410
for model: __shared Model,
415411
learningRate: Float = 1e-3,
@@ -435,26 +431,24 @@ public class RAdam<Model: Differentiable>: Optimizer
435431
let step = Float(self.step)
436432
let beta1Power = pow(beta1, step)
437433
let beta2Power = pow(beta2, step)
438-
let stepSize = self.learningRate * step / (1 - beta1Power)
434+
// let stepSize = self.learningRate * step / (1 - beta1Power)
439435
secondMoments = beta2 * secondMoments + direction .* direction * (1 - beta2)
440436
firstMoments = beta1 * firstMoments + direction * (1 - beta1)
441-
442437
// Compute maximum length SMA, bias-corrected moving average and approximate length
443438
// SMA
444439
let N_sma_inf = 2 / (1 - beta2) - 1
445440
let N_sma_t = N_sma_inf - 2*step*beta2Power / (1 - beta2Power)
446-
firstMoments_h = firstMoments
447-
448-
if N_sma_t > 4 {
449-
// Comppute Bias corrected second moments, rectification and
450-
// adapted momentum
451-
secondMoments_h = Model.TangentVector.sqrt(secondMoments)
452-
let r = sqrt((N_sma_t-4)*(N_sma_t-2)*N_sma_inf/((N_sma_inf-4)*(N_sma_inf-2)*(N_sma_t)))
453-
model.move(along: -stepSize*sqrt(1 - beta2Power)*firstMoments_h*r./secondMoments_h)
441+
442+
if N_sma_t > 5 {
443+
// Compute Bias corrected second moments, rectification and adapted momentum
444+
let secondMoments_h = Model.TangentVector.sqrt(secondMoments) + epsilon
445+
let stepSize = sqrt((N_sma_t-4)*(N_sma_t-2)*N_sma_inf/((N_sma_inf-4)*(N_sma_inf-2)*(N_sma_t)))
446+
model.move(along: -stepSize*sqrt(1 - beta2Power)*firstMoments./secondMoments_h)
454447
}
455448
else {
456449
// Update with un-adapted momentum
457-
model.move(along: -stepSize*firstMoments_h)
450+
let stepSize = self.learningRate * step / (1 - beta1Power)
451+
model.move(along: -stepSize*firstMoments)
458452
}
459453
}
460-
}
454+
}

Tests/TensorFlowTests/SequentialTests.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ final class SequentialTests: XCTestCase {
6363
}
6464
}
6565
XCTAssertEqual(model.inferring(from: [[0, 0], [0, 1], [1, 0], [1, 1]]),
66-
[[0.4884567], [0.4884567], [0.4884567], [0.4884567]])
66+
[[0.5053005], [0.5053005], [0.5053005], [0.5053005]])
6767
}
6868

6969
static var allTests = [

0 commit comments

Comments
 (0)