This repository was archived by the owner on Jul 1, 2023. It is now read-only.

[Optimizer] Simplify optimizers using generalized vector math. #218

Merged: 15 commits, Jun 27, 2019
Sources/TensorFlow/Loss.swift (3 changes: 2 additions & 1 deletion)
@@ -231,6 +231,7 @@ public func sigmoidCrossEntropy<Scalar: TensorFlowFloatingPoint>(
// This numerical stable implementation is based on tf.nn.sigmoid_cross_entropy_with_logits.

let maxLogitsWithZero = max(logits, Tensor(0))
let loss = maxLogitsWithZero - logits * labels + log(1 + exp(-abs(logits)))
var loss = maxLogitsWithZero - logits * labels
loss = loss + log(1 + exp(-abs(logits)))
return loss.mean()
}
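The change above splits the numerically stable sigmoid cross-entropy expression into two statements; the formula itself, max(x, 0) - x * z + log(1 + exp(-|x|)), is unchanged. Below is a scalar sketch of that formula; the function and its names are invented for illustration, not library API.

```swift
import Foundation

// Scalar sketch of the numerically stable sigmoid cross-entropy used above:
// max(x, 0) - x * z + log(1 + exp(-|x|)), where x is a logit and z a label.
// Building `loss` in two steps mirrors the split in the diff.
func stableSigmoidCrossEntropy(logit x: Double, label z: Double) -> Double {
    var loss = max(x, 0) - x * z
    loss += log(1 + exp(-abs(x)))
    return loss
}

// A very negative logit with label 1 stays finite; the naive log(1 + exp(-x))
// form would overflow in exp before the log is taken.
print(stableSigmoidCrossEntropy(logit: -1000, label: 1))  // 1000.0
```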
Sources/TensorFlow/Operators/Math.swift (54 changes: 43 additions & 11 deletions)
@@ -161,22 +161,54 @@ extension Tensor: ElementaryFunctions where Scalar: TensorFlowFloatingPoint {
// Vector Space
//===------------------------------------------------------------------------------------------===//

extension Tensor: VectorProtocol where Scalar: Numeric {
public typealias VectorSpaceScalar = Scalar
extension Tensor: VectorProtocol where Scalar: TensorFlowFloatingPoint {
public typealias VectorSpaceScalar = Float

@differentiable(where Scalar: TensorFlowFloatingPoint)
public func adding(_ scalar: Scalar) -> Self {
self + scalar
// @differentiable(where Scalar: TensorFlowFloatingPoint)
public func scaled(by scale: Float) -> Self {
Scalar(scale) * self
}

@differentiable(where Scalar: TensorFlowFloatingPoint)
public func subtracting(_ scalar: Scalar) -> Self {
self - scalar
// @differentiable(where Scalar: TensorFlowFloatingPoint)
public func adding(_ scalar: Float) -> Self {
self + Scalar(scalar)
}

@differentiable(where Scalar: TensorFlowFloatingPoint)
public func scaled(by scalar: Scalar) -> Self {
self * scalar
// @differentiable(where Scalar: TensorFlowFloatingPoint)
public func subtracting(_ scalar: Float) -> Self {
self - Scalar(scalar)
}
}

extension VectorProtocol {
static func + (lhs: VectorSpaceScalar, rhs: Self) -> Self {
rhs.adding(lhs)
}

static func + (lhs: Self, rhs: VectorSpaceScalar) -> Self {
lhs.adding(rhs)
}

static func - (lhs: Self, rhs: VectorSpaceScalar) -> Self {
lhs.subtracting(rhs)
}

static func * (lhs: VectorSpaceScalar, rhs: Self) -> Self {
rhs.scaled(by: lhs)
}

static func * (lhs: Self, rhs: VectorSpaceScalar) -> Self {
lhs.scaled(by: rhs)
}
}

extension VectorProtocol where VectorSpaceScalar: SignedNumeric {
static prefix func - (x: Self) -> Self {
.zero - x
}

static func - (lhs: VectorSpaceScalar, rhs: Self) -> Self {
(-rhs).adding(lhs)
}
}

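The `VectorProtocol` extensions added above define `+`, `-`, and `*` between a vector and its `VectorSpaceScalar` once, in terms of `adding(_:)`, `subtracting(_:)`, and `scaled(by:)`, so the optimizers can be written as plain vector arithmetic instead of key-path loops. The following self-contained toy shows the same pattern; `MiniVector` and `Vec2` are stand-ins invented for this sketch, not library types.

```swift
// A toy mirror of the pattern above: scalar-vector operators defined once on a
// protocol extension, forwarding to `adding`, `subtracting`, and `scaled(by:)`.
protocol MiniVector {
    associatedtype Scalar: FloatingPoint
    func adding(_ scalar: Scalar) -> Self
    func subtracting(_ scalar: Scalar) -> Self
    func scaled(by scalar: Scalar) -> Self
}

extension MiniVector {
    static func + (lhs: Self, rhs: Scalar) -> Self { lhs.adding(rhs) }
    static func + (lhs: Scalar, rhs: Self) -> Self { rhs.adding(lhs) }
    static func - (lhs: Self, rhs: Scalar) -> Self { lhs.subtracting(rhs) }
    static func * (lhs: Scalar, rhs: Self) -> Self { rhs.scaled(by: lhs) }
    static func * (lhs: Self, rhs: Scalar) -> Self { lhs.scaled(by: rhs) }
}

// A two-element vector is enough to see the operators in action.
struct Vec2: MiniVector {
    var x, y: Double
    func adding(_ s: Double) -> Vec2 { Vec2(x: x + s, y: y + s) }
    func subtracting(_ s: Double) -> Vec2 { Vec2(x: x - s, y: y - s) }
    func scaled(by s: Double) -> Vec2 { Vec2(x: x * s, y: y * s) }
}

let v = Vec2(x: 2, y: 4)
print(0.5 * v - 1.0)  // Vec2(x: 0.0, y: 1.0)
```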
Sources/TensorFlow/Optimizer.swift (87 changes: 45 additions & 42 deletions)
@@ -25,8 +25,7 @@ public protocol Optimizer {
var learningRate: Scalar { get set }
/// Updates the specified differentiable variables along the specified
/// direction.
mutating func update(_ variables: inout Model.AllDifferentiableVariables,
along direction: Model.TangentVector)
mutating func update(_ variables: inout Model, along direction: Model.TangentVector)
}

fileprivate extension Tensor where Scalar: Numeric {
@@ -35,14 +34,13 @@ fileprivate extension Tensor where Scalar: Numeric {
}
}

// MARK: - Key-path based optimizers

/// Adam optimizer.
///
/// Reference: ["Adam - A Method for Stochastic Optimization"](
/// https://arxiv.org/abs/1412.6980v8)
public class Adam<Model: Layer>: Optimizer
where Model.AllDifferentiableVariables == Model.TangentVector {
public typealias Model = Model
/// The learning rate.
public var learningRate: Float
/// A coefficient used to calculate the first and second moments of
@@ -96,7 +94,7 @@ public class Adam<Model: Layer>: Optimizer
}
}


// TODO: Deprecate this when `Differentiable.AllDifferentiableVariables` is removed.
public func update(_ model: inout Model.AllDifferentiableVariables,
along direction: Model.AllDifferentiableVariables) {
step += 1
@@ -127,6 +125,11 @@ public class Adam<Model: Layer>: Optimizer
sqrt(secondMoments[keyPath: kp]) + Double(epsilon)
}
}

public func update(_ model: inout Model,
along direction: Model.TangentVector) {
update(&model.allDifferentiableVariables, along: direction)
}
}
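For reference, the Adam update the class above implements follows the cited paper: exponential moving averages of the gradient and its square, with the bias correction folded into the effective step size. A scalar sketch under that assumption; the function and parameter names are invented for illustration, not library API.

```swift
import Foundation

// Scalar Adam step: moving averages of the gradient and its square, with the
// bias correction folded into the step size, as in the reference paper.
func adamStep(parameter: inout Double, firstMoment m: inout Double,
              secondMoment v: inout Double, gradient g: Double, step t: Int,
              learningRate: Double = 1e-3, beta1: Double = 0.9,
              beta2: Double = 0.999, epsilon: Double = 1e-8) {
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g * g
    let correction = (1 - pow(beta2, Double(t))).squareRoot() / (1 - pow(beta1, Double(t)))
    parameter -= learningRate * correction * m / (v.squareRoot() + epsilon)
}
```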

/// RMSProp optimizer.
@@ -139,6 +142,7 @@ public class Adam<Model: Layer>: Optimizer
/// http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
public class RMSProp<Model: Layer>: Optimizer
where Model.AllDifferentiableVariables == Model.TangentVector {
public typealias Model = Model
/// The learning rate.
public var learningRate: Float
// TODO: Document `rho`. Keras doesn't document `rho`.
@@ -176,7 +180,7 @@ public class RMSProp<Model: Layer>: Optimizer
}
}


// TODO: Deprecate this when `Differentiable.AllDifferentiableVariables` is removed.
public func update(_ model: inout Model.AllDifferentiableVariables,
along direction: Model.TangentVector) {
step += 1
@@ -195,14 +199,21 @@ public class RMSProp<Model: Layer>: Optimizer
(sqrt(alpha[keyPath: kp]) + Double(epsilon))
}
}

public func update(_ model: inout Model,
along direction: Model.TangentVector) {
update(&model.allDifferentiableVariables, along: direction)
}
}
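RMSProp scales each step by a running average of squared gradients, which matches the `sqrt(alpha) + epsilon` denominator visible in the diff. A scalar sketch that omits the learning-rate decay the class also supports; the names are invented for illustration.

```swift
// Scalar RMSProp step: exponentially weighted average of squared gradients
// divides the raw gradient step.
func rmsPropStep(parameter: inout Double, alpha: inout Double, gradient g: Double,
                 learningRate: Double = 1e-3, rho: Double = 0.9,
                 epsilon: Double = 1e-8) {
    alpha = rho * alpha + (1 - rho) * g * g
    parameter -= learningRate * g / (alpha.squareRoot() + epsilon)
}
```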

/// Stochastic gradient descent (SGD) optimizer.
///
/// An optimizer that implements stochastic gradient descent, with support for momentum, learning
/// rate decay, and Nesterov momentum.
public class SGD<Model: Layer>: Optimizer
where Model.AllDifferentiableVariables == Model.TangentVector {
public class SGD<Model: Differentiable>: Optimizer
where Model.TangentVector: VectorProtocol & ElementaryFunctions,
Model.TangentVector.VectorSpaceScalar == Float {
public typealias Model = Model
/// The learning rate.
public var learningRate: Float
/// The momentum factor. It accelerates stochastic gradient descent in the relevant direction
@@ -212,8 +223,8 @@ public class SGD<Model: Layer>: Optimizer
public var decay: Float
/// Use Nesterov momentum if true.
public var nesterov: Bool
/// The velocity state of the model
public var velocity: Model.AllDifferentiableVariables
/// The velocity state of the model.
public var velocity: Model.TangentVector = .zero
/// The set of steps taken.
public var step: Int = 0

@@ -232,53 +243,38 @@ public class SGD<Model: Layer>: Optimizer
self.momentum = momentum
self.decay = decay
self.nesterov = nesterov
velocity = model.allDifferentiableVariables
for kp in velocity.recursivelyAllWritableKeyPaths(to: Tensor<Float>.self) {
velocity[keyPath: kp].resetToZero()
}
for kp in velocity.recursivelyAllWritableKeyPaths(to: Tensor<Double>.self) {
velocity[keyPath: kp].resetToZero()
}
}

// TODO: Deprecate this when `Differentiable.AllDifferentiableVariables` is removed.
public func update(_ model: inout Model.AllDifferentiableVariables,
along direction: Model.TangentVector) {
step += 1
let learningRate = self.learningRate * 1 / (1 + decay * Float(step))
for kp in model.recursivelyAllWritableKeyPaths(to: Tensor<Float>.self) {
velocity[keyPath: kp] =
momentum * velocity[keyPath: kp] - learningRate * direction[keyPath: kp]
if nesterov {
model[keyPath: kp] +=
momentum * velocity[keyPath: kp] - learningRate * direction[keyPath: kp]
} else {
model[keyPath: kp] += velocity[keyPath: kp]
}
}
for kp in model.recursivelyAllWritableKeyPaths(to: Tensor<Double>.self) {
velocity[keyPath: kp] =
Double(momentum) * velocity[keyPath: kp] -
Double(learningRate) * direction[keyPath: kp]
if nesterov {
model[keyPath: kp] +=
Double(momentum) * velocity[keyPath: kp] - Double(learningRate) *
direction[keyPath: kp]
} else {
model[keyPath: kp] += velocity[keyPath: kp]
}
velocity = momentum * velocity - direction * learningRate
if nesterov {
model.move(along: momentum * velocity - direction * learningRate)
} else {
model.move(along: velocity)
}
}

public func update(_ model: inout Model,
along direction: Model.TangentVector) {
update(&model.allDifferentiableVariables, along: direction)
}
}
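The rewritten SGD update replaces the per-key-path loops with generalized vector math: v = momentum * v - lr * g, then either theta += v or, with Nesterov momentum, theta += momentum * v - lr * g, where lr is the decayed learning rate. A scalar sketch of the same recurrence; the names are invented for illustration.

```swift
// Scalar SGD step with momentum, learning-rate decay, and optional Nesterov
// momentum, mirroring the vectorized update above.
func sgdStep(parameter: inout Double, velocity: inout Double, gradient g: Double,
             step: Int, learningRate: Double = 0.01, momentum: Double = 0.0,
             decay: Double = 0.0, nesterov: Bool = false) {
    let lr = learningRate / (1 + decay * Double(step))  // decayed learning rate
    velocity = momentum * velocity - lr * g
    if nesterov {
        parameter += momentum * velocity - lr * g
    } else {
        parameter += velocity
    }
}
```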

// MARK: - Manifold optimizers

/// A Riemann manifold stochastic gradient descent (SGD) optimizer.
public class RiemannSGD<Model: Layer, Scalar: FloatingPoint>: Optimizer
where Model.TangentVector: VectorProtocol, Model.TangentVector.VectorSpaceScalar == Scalar {
public class RiemannSGD<Model: Differentiable>: Optimizer
where Model.TangentVector: VectorProtocol,
Model.TangentVector.VectorSpaceScalar: FloatingPoint {
public typealias Scalar = Model.TangentVector.VectorSpaceScalar
/// The learning rate.
public var learningRate: Scalar
public var learningRate: Model.TangentVector.VectorSpaceScalar

public init(learningRate: Scalar) {
public init(learningRate: Model.TangentVector.VectorSpaceScalar) {
self.learningRate = learningRate
}

@@ -305,6 +301,7 @@ public class RiemannSGD<Model: Layer, Scalar: FloatingPoint>: Optimizer
///
public class AdaGrad<Model: Layer>: Optimizer
where Model.AllDifferentiableVariables == Model.TangentVector {
public typealias Model = Model
/// The learning rate.
public var learningRate: Float
/// The smoothing factor (ρ). Typical values are `0.5`, `0.9`, and `0.99`, for smoothing over 2,
@@ -337,6 +334,7 @@ public class AdaGrad<Model: Layer>: Optimizer
}
}

// TODO: Deprecate this when `Differentiable.AllDifferentiableVariables` is removed.
public func update(_ model: inout Model.AllDifferentiableVariables,
along direction: Model.TangentVector) {
for kp in model.recursivelyAllWritableKeyPaths(to: Tensor<Float>.self) {
@@ -351,4 +349,9 @@ public class AdaGrad<Model: Layer>: Optimizer
(sqrt(alpha[keyPath: kp] + Double(epsilon)))
}
}

public func update(_ model: inout Model,
along direction: Model.TangentVector) {
update(&model.allDifferentiableVariables, along: direction)
}
}
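Finally, a scalar sketch of the textbook AdaGrad rule: accumulate squared gradients and divide by their square root, with epsilon inside the root as in the `sqrt(alpha + epsilon)` line above. The class also exposes a smoothing factor rho that this sketch omits, and the names are invented for illustration.

```swift
// Scalar AdaGrad step: accumulate squared gradients so that frequently
// updated parameters take progressively smaller steps.
func adaGradStep(parameter: inout Double, alpha: inout Double, gradient g: Double,
                 learningRate: Double = 1e-3, epsilon: Double = 1e-8) {
    alpha += g * g
    parameter -= learningRate * g / (alpha + epsilon).squareRoot()
}
```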
Tests/TensorFlowTests/SequentialTests.swift (16 changes: 13 additions & 3 deletions)
@@ -29,7 +29,10 @@ final class SequentialTests: XCTestCase {
}
}
var model = Model()
let optimizer = SGD(for: model, learningRate: 0.02)
let sgd = SGD(for: model, learningRate: 0.02)
let rmsprop = RMSProp(for: model, learningRate: 0.02)
let adam = Adam(for: model, learningRate: 0.02)
let adagrad = AdaGrad(for: model, learningRate: 0.02)
let x: Tensor<Float> = [[0, 0], [0, 1], [1, 0], [1, 1]]
let y: Tensor<Float> = [0, 1, 1, 0]
Context.local.learningPhase = .training
@@ -38,10 +41,17 @@
let ŷ = model(x)
return meanSquaredError(predicted: ŷ, expected: y)
}
optimizer.update(&model.allDifferentiableVariables, along: 𝛁model)
sgd.update(&model, along: 𝛁model)
sgd.update(&model.allDifferentiableVariables, along: 𝛁model)
rmsprop.update(&model, along: 𝛁model)
rmsprop.update(&model.allDifferentiableVariables, along: 𝛁model)
adam.update(&model, along: 𝛁model)
adam.update(&model.allDifferentiableVariables, along: 𝛁model)
adagrad.update(&model, along: 𝛁model)
adagrad.update(&model.allDifferentiableVariables, along: 𝛁model)
}
XCTAssertEqual(model.inferring(from: [[0, 0], [0, 1], [1, 0], [1, 1]]),
[[ 0.4904838], [0.49942452], [0.49740878], [ 0.5106092]])
[[0.47705528], [0.47705528], [0.47705528], [0.47705528]])
}

static var allTests = [
Tests/TensorFlowTests/TrivialModelTests.swift (2 changes: 1 addition & 1 deletion)
@@ -50,7 +50,7 @@ final class TrivialModelTests: XCTestCase {
let ŷ = classifier(x)
return meanSquaredError(predicted: ŷ, expected: y)
}
optimizer.update(&classifier.allDifferentiableVariables, along: 𝛁model)
optimizer.update(&classifier, along: 𝛁model)
}
let ŷ = classifier.inferring(from: x)
XCTAssertEqual(round(ŷ), y)