This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Add AdaMax optimizer #243

Closed · wants to merge 1 commit
98 changes: 98 additions & 0 deletions Sources/TensorFlow/Optimizer.swift
@@ -129,6 +129,104 @@ public class Adam<Model: Layer>: Optimizer
}
}

/// AdaMax optimizer.
///
/// A variant of Adam based on the infinity-norm.
///
/// Reference: Section 7.1 of ["Adam: A Method for Stochastic Optimization"](
/// https://arxiv.org/abs/1412.6980v8)
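///
/// With gradient `g` at step `t`, the update rule from the paper (which the
/// `update(_:along:)` method below is intended to implement) is:
///
///     m_t     = beta1 * m_{t-1} + (1 - beta1) * g_t
///     u_t     = max(beta2 * u_{t-1}, abs(g_t))
///     theta_t = theta_{t-1} - (learningRate / (1 - beta1^t)) * m_t / u_t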
public class AdaMax<Model: Layer>: Optimizer
where Model.AllDifferentiableVariables == Model.TangentVector {
/// The learning rate.
public var learningRate: Float
/// Decay rate used to estimate the first moment (mean) of gradients.
public var beta1: Float
/// Decay rate used to estimate the exponentially weighted infinity norm.
public var beta2: Float
/// A small scalar added to the denominator to improve numerical stability.
public var epsilon: Float
/// The learning rate decay.
public var decay: Float
/// The step count.
public var step: Int = 0
/// The first moments (exponentially decaying averages) of the gradients.
public var firstMoments: Model.AllDifferentiableVariables
/// The exponentially weighted infinity norm of the gradients.
public var infinityNorm: Model.AllDifferentiableVariables

public init(
for model: __shared Model,
learningRate: Float = 1e-3,
beta1: Float = 0.9,
beta2: Float = 0.999,
epsilon: Float = 1e-8,
decay: Float = 0
) {
precondition(learningRate >= 0, "Learning rate must be non-negative")
precondition(0 <= beta1 && beta1 <= 1, "Beta parameter must be between 0 and 1")
precondition(0 <= beta2 && beta2 <= 1, "Beta parameter must be between 0 and 1")
precondition(decay >= 0, "Learning rate decay must be non-negative")

self.learningRate = learningRate
self.beta1 = beta1
self.beta2 = beta2
self.epsilon = epsilon
self.decay = decay

// Initialize first & second moments to be zeros of the same shape.
// We can't use `Model.AllDifferentiableVariables.zero` due to the
// interaction between Key Paths and Differentiable Arrays.
firstMoments = model.allDifferentiableVariables
infinityNorm = model.allDifferentiableVariables
for kp in firstMoments.recursivelyAllWritableKeyPaths(to: Tensor<Float>.self) {
firstMoments[keyPath: kp].resetToZero()
infinityNorm[keyPath: kp].resetToZero()
}
for kp in firstMoments.recursivelyAllWritableKeyPaths(to: Tensor<Double>.self) {
firstMoments[keyPath: kp].resetToZero()
infinityNorm[keyPath: kp].resetToZero()
}
}

public func update(_ model: inout Model.AllDifferentiableVariables,
along direction: Model.AllDifferentiableVariables) {
step += 1
let learningRate = self.learningRate / (1 + decay * Float(step))
// Note: unlike Adam, AdaMax applies bias correction only to the first moment
// estimate; the exponentially weighted infinity norm needs no correction.
let stepSize = learningRate / (1 - pow(beta1, Float(step)))
// Update Float & Double Tensor variables.
for kp in model.recursivelyAllWritableKeyPaths(to: Tensor<Float>.self) {
// Update biased first moment estimate.
firstMoments[keyPath: kp] =
beta1 * firstMoments[keyPath: kp] + (1 - beta1) * direction[keyPath: kp]
// Update the exponentially weighted infinity norm.
infinityNorm[keyPath: kp] =
max(beta2 * infinityNorm[keyPath: kp], abs(direction[keyPath: kp]))
// Update model parameters.
model[keyPath: kp] -=
stepSize * firstMoments[keyPath: kp]
/ (infinityNorm[keyPath: kp] + epsilon)
}
for kp in model.recursivelyAllWritableKeyPaths(to: Tensor<Double>.self) {
// Update biased first moment estimate.
firstMoments[keyPath: kp] =
Double(beta1) * firstMoments[keyPath: kp]
+ Double(1 - beta1) * direction[keyPath: kp]
// Update the exponentially weighted infinity norm.
infinityNorm[keyPath: kp] =
max(Double(beta2) * infinityNorm[keyPath: kp], abs(direction[keyPath: kp]))
// Update model parameters.
model[keyPath: kp] -=
Double(stepSize) * firstMoments[keyPath: kp]
/ (infinityNorm[keyPath: kp] + Double(epsilon))
}
}
}

/// RMSProp optimizer.
///
/// It is recommended to leave the parameters of this optimizer at their default values (except the
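For context, here is a minimal usage sketch of the AdaMax optimizer added above. It is not part of the diff; the model type `TinyXOR`, the hyperparameter values, and the unseeded `Dense(inputSize:outputSize:activation:)` initializer are illustrative assumptions, while the `gradient`, `update(_:along:)`, and `allDifferentiableVariables` calls mirror the test added below.

import TensorFlow

struct TinyXOR: Layer {
    var l1 = Dense<Float>(inputSize: 2, outputSize: 4, activation: relu)
    var l2 = Dense<Float>(inputSize: 4, outputSize: 1, activation: relu)

    @differentiable
    func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
        return l2(l1(input))
    }
}

var model = TinyXOR()
// All hyperparameters are spelled out; these values match the initializer
// defaults except for the learning rate.
let optimizer = AdaMax(
    for: model,
    learningRate: 2e-2,
    beta1: 0.9,
    beta2: 0.999,
    epsilon: 1e-8,
    decay: 0)

let x: Tensor<Float> = [[0, 0], [0, 1], [1, 0], [1, 1]]
let y: Tensor<Float> = [[0], [1], [1], [0]]

for _ in 0..<400 {
    // Differentiate the loss with respect to the model parameters.
    let 𝛁model = model.gradient { model -> Tensor<Float> in
        let ŷ = model(x)
        return meanSquaredError(predicted: ŷ, expected: y)
    }
    // Apply one AdaMax step.
    optimizer.update(&model.allDifferentiableVariables, along: 𝛁model)
}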
49 changes: 49 additions & 0 deletions Tests/TensorFlowTests/OptimizersTests.swift
@@ -64,8 +64,57 @@ final class OptimizerTests: XCTestCase {
XCTAssertEqual(round(ŷ), y)
}

func testAdaMax() {
struct Classifier: Layer {
var l1, l2: Dense<Float>
init(hiddenSize: Int) {
l1 = Dense<Float>(
inputSize: 2,
outputSize: hiddenSize,
activation: relu,
seed: (0xfffffff, 0xfeeff)
)
l2 = Dense<Float>(
inputSize: hiddenSize,
outputSize: 1,
activation: relu,
seed: (0xffeffe, 0xfffe)
)
}
@differentiable
func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
let h1 = l1(input)
return l2(h1)
}
}

var classifier = Classifier(hiddenSize: 4)
let optimizer = AdaMax(for: classifier, learningRate: 0.02)
let x: Tensor<Float> = [[0, 0], [0, 1], [1, 0], [1, 1]]
let y: Tensor<Float> = [[0], [1], [1], [0]]

Context.local.learningPhase = .training

// The untrained classifier should not yet predict the XOR labels correctly.
var ŷ = classifier.inferring(from: x)
XCTAssertNotEqual(round(ŷ), y)

for _ in 0..<400 {
let 𝛁model = classifier.gradient { classifier -> Tensor<Float> in
let ŷ = classifier(x)
return meanSquaredError(predicted: ŷ, expected: y)
}
optimizer.update(&classifier.allDifferentiableVariables, along: 𝛁model)
}

// The trained classifier should predict the XOR labels correctly.
ŷ = classifier.inferring(from: x)
XCTAssertEqual(round(ŷ), y)
}

static var allTests = [
("testAdaGrad", testAdaGrad),
("testAdaMax", testAdaMax),
Suggested change (drop the trailing comma on the last entry):
- ("testAdaMax", testAdaMax),
+ ("testAdaMax", testAdaMax)

]
}