This repository was archived by the owner on Jul 1, 2023. It is now read-only.

use .moments() in LayerNorm and BatchNorm layers #384

Merged · 6 commits · Jul 24, 2019
24 changes: 12 additions & 12 deletions Sources/TensorFlow/Layers/Normalization.swift
@@ -70,18 +70,19 @@ public struct BatchNorm<Scalar: TensorFlowFloatingPoint>: Layer {
let positiveAxis = (input.rank + axis) % input.rank
var normalizedAxes = Array(0..<input.rank)
normalizedAxes.remove(at: positiveAxis)
- let mean = input.mean(alongAxes: normalizedAxes)
- let variance = input.variance(alongAxes: normalizedAxes)
- runningMean.value += (mean - runningMean.value) * (1 - momentum)
- runningVariance.value += (variance - runningVariance.value) * (1 - momentum)
- let inv = rsqrt(variance + epsilon) * scale
- return (input - mean) * inv + offset
+ let moments = input.moments(alongAxes: normalizedAxes)
+ runningMean.value += (moments.mean - runningMean.value) * (1 - momentum)
+ runningVariance.value += (moments.variance - runningVariance.value) * (1 - momentum)
+ let inv = rsqrt(moments.variance + epsilon) * scale.reshaped(to: moments.variance.shape)
+ return (input - moments.mean) * inv + offset.reshaped(to: moments.mean.shape)
}

@differentiable
private func applyingInference(to input: Tensor<Scalar>) -> Tensor<Scalar> {
- let inv = rsqrt(runningVariance.value + epsilon) * scale
- return (input - runningMean.value) * inv + offset
+ let scaleShape = runningVariance.value.shape
+ let offsetShape = runningMean.value.shape
+ let inv = rsqrt(runningVariance.value + epsilon) * scale.reshaped(to: scaleShape)
+ return (input - runningMean.value) * inv + offset.reshaped(to: offsetShape)
}

/// Returns the output obtained from applying the layer to the given input.
@@ -187,9 +188,8 @@ public struct LayerNorm<Scalar: TensorFlowFloatingPoint>: Layer {
/// - Returns: The output.
@differentiable
public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
- let mean = input.mean(alongAxes: axis)
- let variance = input.variance(alongAxes: axis)
- let inv = rsqrt(variance + epsilon) * scale
- return (input - mean) * inv + offset
+ let moments = input.moments(alongAxes: axis)
+ let inv = rsqrt(moments.variance + epsilon) * scale.reshaped(to: moments.variance.shape)
+ return (input - moments.mean) * inv + offset.reshaped(to: moments.mean.shape)
}
}
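
The core of the Normalization.swift change is collapsing two separate reductions into a single `moments(alongAxes:)` call that returns both statistics together, so the mean can be reused for the variance computation; the added `.reshaped(to:)` calls align `scale` and `offset` with the shape of the computed statistics. A minimal sketch of the idea, assuming `moments(alongAxes:)` returns a value exposing `mean` and `variance` properties, as its use in the diff implies (the toy tensor below is illustrative, not from the PR):

```swift
import TensorFlow

let x = Tensor<Float>([[1, 2, 3],
                       [4, 5, 6]])

// Before: two separate reductions over the same axes.
let mean = x.mean(alongAxes: [0])
let variance = x.variance(alongAxes: [0])

// After: one call that yields both statistics together.
let moments = x.moments(alongAxes: [0])
print(moments.mean)     // matches `mean`
print(moments.variance) // matches `variance`
```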
88 changes: 87 additions & 1 deletion Tests/TensorFlowTests/LayerTests.swift
@@ -410,6 +410,90 @@ final class LayerTests: XCTestCase {
let expected = Tensor<Float>([[0.0], [0.7615942], [0.9640276], [0.9950547], [0.9993292]])
XCTAssertEqual(output, expected)
}

func testBatchNorm() {
let x = Tensor<Float>([
[ -1.0474433, -0.11914538, -0.08634827, 0.15446888, 1.0572497],
[ 1.5165012, 0.3753972, -0.30856386, -0.3100725, -1.9584457],
[ 0.006384419, 1.4424847, 0.91568077, 0.66328526, -1.0794537],
[ 1.056803, 0.14263044, -1.8308276, 0.4189805, 0.6933893],
[ 0.30175626, -0.16121633, -0.4191958, -0.53092813, -0.029484272]])
let bnLayer = BatchNorm<Float>(featureCount: 5, axis: 0)
Context.local.learningPhase = .training
let trainingValue = bnLayer(x)
let grad = gradient(at: x, bnLayer) { $1($0).squared().sum() }
// The expected values and gradients were computed using the following Python code:
// ```
// x = tf.constant(
// [[ -1.0474433, -0.11914538, -0.08634827, 0.15446888, 1.0572497],
// [ 1.5165012, 0.3753972, -0.30856386, -0.3100725, -1.9584457],
// [ 0.006384419, 1.4424847, 0.91568077, 0.66328526, -1.0794537],
// [ 1.056803, 0.14263044, -1.8308276, 0.4189805, 0.6933893],
// [ 0.30175626, -0.16121633, -0.4191958, -0.53092813, -0.029484272]])
// scale = tf.reshape(tf.constant([1., 1., 1., 1., 1.]), [5, 1])
// offset = tf.reshape(tf.constant([0., 0., 0., 0., 0.]), [5, 1])
// (mean, var) = tf.nn.moments(x, axes=1, keepdims=True)
// bn = tf.nn.batch_normalization(x, mean, var, offset=offset, scale=scale, variance_epsilon=0.001)
// scaled = tf.reduce_sum(tf.square(bn))
// g = tf.gradients(scaled, [x, offset, scale])
// init = tf.initialize_all_variables()
// with tf.Session() as sess:
// sess.run(init)
// print(sess.run([bn, g]))
// ```
let expectedTrainingValue = Tensor<Float>([
[-1.5439795 , -0.16477099, -0.11604305, 0.24174842, 1.5830451 ],
[ 1.4639764 , 0.45368853, -0.15186328, -0.15319899, -1.6126028 ],
[-0.44139984, 1.2124169 , 0.60574806, 0.3150888 , -1.6918538 ],
[ 0.9507547 , 0.04595902, -1.9072568 , 0.31947452, 0.5910686 ],
[ 1.5834246 , 0.02224666, -0.8476793 , -1.2244489 , 0.46645695]])

let expectedInputGradient = Tensor<Float>([
[-1.0127544e-02, -1.0807812e-03, -7.6115131e-04, 1.5857220e-03, 1.0383606e-02],
[ 2.0323221e-03, 6.2976527e-04, -2.1077941e-04, -2.1265696e-04, -2.2384699e-03],
[-1.3483668e-03, 3.7030075e-03, 1.8500184e-03, 9.6232636e-04, -5.1673558e-03],
[ 1.8438101e-03, 8.9146197e-05, -3.6990643e-03, 6.1964989e-04, 1.1463165e-03],
[ 1.2142579e-01, 1.7060755e-03, -6.5005139e-02, -9.3897656e-02, 3.5770576e-02]])
let expectedScaleGradient = Tensor<Float>([9.977925, 9.992161, 9.986738, 9.990202, 9.886292])
let expectedOffsetGradient = Tensor<Float>([0.0, 0.0, 0.0, 0.0, 0.0])
assertEqual(expectedTrainingValue, trainingValue, accuracy: 1e-5)
assertEqual(expectedInputGradient, grad.0, accuracy: 1e-5)
assertEqual(expectedScaleGradient, grad.1.scale, accuracy: 1e-5)
assertEqual(expectedOffsetGradient, grad.1.offset, accuracy: 1e-5)
}

func testLayerNorm() {
let x = Tensor<Float>([
[ -1.0474433, -0.11914538, -0.08634827, 0.15446888, 1.0572497],
[ 1.5165012, 0.3753972, -0.30856386, -0.3100725, -1.9584457],
[ 0.006384419, 1.4424847, 0.91568077, 0.66328526, -1.0794537],
[ 1.056803, 0.14263044, -1.8308276, 0.4189805, 0.6933893],
[ 0.30175626, -0.16121633, -0.4191958, -0.53092813, -0.029484272]])
let lnLayer = LayerNorm<Float>(featureCount: 5, axis: 1)
let value = lnLayer(x)
let grad = gradient(at: x, lnLayer) { $1($0).squared().sum() }
// Uses the same expected values as testBatchNorm() above, because LayerNorm with features
// on axis 1 is equivalent to BatchNorm (in training mode) with features on axis 0.
let expectedValue = Tensor<Float>([
[-1.5439795 , -0.16477099, -0.11604305, 0.24174842, 1.5830451 ],
[ 1.4639764 , 0.45368853, -0.15186328, -0.15319899, -1.6126028 ],
[-0.44139984, 1.2124169 , 0.60574806, 0.3150888 , -1.6918538 ],
[ 0.9507547 , 0.04595902, -1.9072568 , 0.31947452, 0.5910686 ],
[ 1.5834246 , 0.02224666, -0.8476793 , -1.2244489 , 0.46645695]])

let expectedInputGradient = Tensor<Float>([
[-1.0127544e-02, -1.0807812e-03, -7.6115131e-04, 1.5857220e-03, 1.0383606e-02],
[ 2.0323221e-03, 6.2976527e-04, -2.1077941e-04, -2.1265696e-04, -2.2384699e-03],
[-1.3483668e-03, 3.7030075e-03, 1.8500184e-03, 9.6232636e-04, -5.1673558e-03],
[ 1.8438101e-03, 8.9146197e-05, -3.6990643e-03, 6.1964989e-04, 1.1463165e-03],
[ 1.2142579e-01, 1.7060755e-03, -6.5005139e-02, -9.3897656e-02, 3.5770576e-02]])
let expectedScaleGradient = Tensor<Float>([9.977925, 9.992161, 9.986738, 9.990202, 9.886292])
let expectedOffsetGradient = Tensor<Float>([0.0, 0.0, 0.0, 0.0, 0.0])
assertEqual(expectedValue, value, accuracy: 1e-5)
assertEqual(expectedInputGradient, grad.0, accuracy: 1e-5)
assertEqual(expectedScaleGradient, grad.1.scale, accuracy: 1e-5)
assertEqual(expectedOffsetGradient, grad.1.offset, accuracy: 1e-5)
}

static var allTests = [
("testSequential", testSequential),
@@ -443,6 +527,8 @@ final class LayerTests: XCTestCase {
("testSimpleRNNCell", testSimpleRNNCell),
("testDense", testDense),
("testRNN", testRNN),
("testFunction", testFunction)
("testFunction", testFunction),
("testBatchNorm", testBatchNorm),
("testLayerNorm", testLayerNorm)
]
}
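
The comment in testLayerNorm explains why the two tests share expected tensors: BatchNorm with features on axis 0 normalizes along every remaining axis (here, just axis 1), which is exactly what LayerNorm over axis 1 does. A quick numerical check of that claim, sketched under the assumption that both layers use their default scale, offset, and epsilon (the tests above compute both sets of expected values with `variance_epsilon=0.001`):

```swift
import TensorFlow

let x = Tensor<Float>(randomNormal: [5, 5])

// BatchNorm in training mode normalizes with the current batch statistics.
Context.local.learningPhase = .training
let bn = BatchNorm<Float>(featureCount: 5, axis: 0)
let ln = LayerNorm<Float>(featureCount: 5, axis: 1)

// Both layers standardize each row of `x`, so their outputs should
// agree up to floating-point tolerance.
let maxDifference = abs(bn(x) - ln(x)).max()
print(maxDifference) // expected to be ~0
```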