This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit f5222cd

mikowalsrxwei authored and committed
use .moments() in LayerNorm and BatchNorm layers (#384)
Mean and variance in these layers are now calculated using `Tensor.moments()`. I also added tests for both the BatchNorm and LayerNorm layers. The tests turned up a flaw in the shapes of `scale` and `offset`, which were always `[featureCount]` regardless of the input shape or the normalization axis. That shape leads to incorrect broadcasting when the axis being normalized along is not the last axis. I have fixed this by always reshaping `scale` and `offset` before they are used. It seems hacky in that I take the shapes from the calculated `mean` and `variance`, but since the input shape is not known at initialization time I couldn't see a better way to do it. I think the axis argument is probably there for consistency with Keras, but most of the Swift API layers assume inputs and activations are NHWC. So requiring NHWC, eliminating the axis argument, and setting the correct shapes in `init()` would be another option.
1 parent ef48ae9 · commit f5222cd
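As context for the reshaping described in the commit message, here is a minimal standalone sketch, not part of the commit; the input values, shapes, and the epsilon of 0.001 are illustrative. It shows why a `[featureCount]`-shaped `scale` broadcasts incorrectly when the feature axis is not the last axis, and how reshaping to the moments' shape fixes it:

```swift
import TensorFlow

// Features live on axis 0 of a [3, 4] input, so we normalize along axis 1.
// The moments then have shape [3, 1], but a scale of shape [3] would broadcast
// against the last axis, pairing features with the wrong dimension.
let x = Tensor<Float>(shape: [3, 4], scalars: (0..<12).map(Float.init))
let moments = x.moments(alongAxes: [1])  // mean and variance have shape [3, 1]
let scale = Tensor<Float>(ones: [3])     // parameter stored as [featureCount]
let offset = Tensor<Float>(zeros: [3])

// Reshaping the parameters to the moments' shape restores per-feature broadcasting.
let inv = rsqrt(moments.variance + 0.001) * scale.reshaped(to: moments.variance.shape)
let normalized = (x - moments.mean) * inv + offset.reshaped(to: moments.mean.shape)
print(normalized.shape)  // [3, 4]
```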

File tree: 2 files changed, +99 −13 lines


Sources/TensorFlow/Layers/Normalization.swift

Lines changed: 12 additions & 12 deletions
@@ -70,18 +70,19 @@ public struct BatchNorm<Scalar: TensorFlowFloatingPoint>: Layer {
         let positiveAxis = (input.rank + axis) % input.rank
         var normalizedAxes = Array(0..<input.rank)
         normalizedAxes.remove(at: positiveAxis)
-        let mean = input.mean(alongAxes: normalizedAxes)
-        let variance = input.variance(alongAxes: normalizedAxes)
-        runningMean.value += (mean - runningMean.value) * (1 - momentum)
-        runningVariance.value += (variance - runningVariance.value) * (1 - momentum)
-        let inv = rsqrt(variance + epsilon) * scale
-        return (input - mean) * inv + offset
+        let moments = input.moments(alongAxes: normalizedAxes)
+        runningMean.value += (moments.mean - runningMean.value) * (1 - momentum)
+        runningVariance.value += (moments.variance - runningVariance.value) * (1 - momentum)
+        let inv = rsqrt(moments.variance + epsilon) * scale.reshaped(to: moments.variance.shape)
+        return (input - moments.mean) * inv + offset.reshaped(to: moments.mean.shape)
     }

     @differentiable
     private func applyingInference(to input: Tensor<Scalar>) -> Tensor<Scalar> {
-        let inv = rsqrt(runningVariance.value + epsilon) * scale
-        return (input - runningMean.value) * inv + offset
+        let scaleShape = runningVariance.value.shape
+        let offsetShape = runningMean.value.shape
+        let inv = rsqrt(runningVariance.value + epsilon) * scale.reshaped(to: scaleShape)
+        return (input - runningMean.value) * inv + offset.reshaped(to: offsetShape)
     }

     /// Returns the output obtained from applying the layer to the given input.
@@ -187,9 +188,8 @@ public struct LayerNorm<Scalar: TensorFlowFloatingPoint>: Layer {
     /// - Returns: The output.
     @differentiable
     public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
-        let mean = input.mean(alongAxes: axis)
-        let variance = input.variance(alongAxes: axis)
-        let inv = rsqrt(variance + epsilon) * scale
-        return (input - mean) * inv + offset
+        let moments = input.moments(alongAxes: axis)
+        let inv = rsqrt(moments.variance + epsilon) * scale.reshaped(to: moments.variance.shape)
+        return (input - moments.mean) * inv + offset.reshaped(to: moments.mean.shape)
     }
 }

Tests/TensorFlowTests/LayerTests.swift

Lines changed: 87 additions & 1 deletion
@@ -410,6 +410,90 @@ final class LayerTests: XCTestCase {
         let expected = Tensor<Float>([[0.0], [0.7615942], [0.9640276], [0.9950547], [0.9993292]])
         XCTAssertEqual(output, expected)
     }
+
+    func testBatchNorm() {
+        let x = Tensor<Float>([
+            [  -1.0474433,  -0.11914538,  -0.08634827,   0.15446888,    1.0572497],
+            [   1.5165012,    0.3753972,  -0.30856386,   -0.3100725,   -1.9584457],
+            [ 0.006384419,    1.4424847,   0.91568077,   0.66328526,   -1.0794537],
+            [    1.056803,   0.14263044,   -1.8308276,    0.4189805,    0.6933893],
+            [  0.30175626,  -0.16121633,   -0.4191958,  -0.53092813, -0.029484272]])
+        let bnLayer = BatchNorm<Float>(featureCount: 5, axis: 0)
+        Context.local.learningPhase = .training
+        let trainingValue = bnLayer(x)
+        let grad = gradient(at: x, bnLayer) { $1($0).squared().sum() }
+        // The expected values and gradients were computed using the following Python code:
+        // ```
+        // x = tf.constant(
+        //     [[ -1.0474433, -0.11914538, -0.08634827, 0.15446888, 1.0572497],
+        //      [ 1.5165012, 0.3753972, -0.30856386, -0.3100725, -1.9584457],
+        //      [ 0.006384419, 1.4424847, 0.91568077, 0.66328526, -1.0794537],
+        //      [ 1.056803, 0.14263044, -1.8308276, 0.4189805, 0.6933893],
+        //      [ 0.30175626, -0.16121633, -0.4191958, -0.53092813, -0.029484272]])
+        // scale = tf.reshape(tf.constant([1., 1., 1., 1., 1.]), [5, 1])
+        // offset = tf.reshape(tf.constant([0., 0., 0., 0., 0.]), [5, 1])
+        // (mean, var) = tf.nn.moments(x, axes=1, keepdims=True)
+        // bn = tf.nn.batch_normalization(x, mean, var, offset=offset, scale=scale, variance_epsilon=0.001)
+        // scaled = tf.reduce_sum(tf.square(bn))
+        // g = tf.gradients(scaled, [x, offset, scale])
+        // init = tf.initialize_all_variables()
+        // with tf.Session() as sess:
+        //     sess.run(init)
+        //     print(sess.run([bn, g]))
+        // ```
+        let expectedTrainingValue = Tensor<Float>([
+            [-1.5439795,  -0.16477099, -0.11604305,  0.24174842,  1.5830451],
+            [ 1.4639764,   0.45368853, -0.15186328, -0.15319899, -1.6126028],
+            [-0.44139984,  1.2124169,   0.60574806,  0.3150888,  -1.6918538],
+            [ 0.9507547,   0.04595902, -1.9072568,   0.31947452,  0.5910686],
+            [ 1.5834246,   0.02224666, -0.8476793,  -1.2244489,   0.46645695]])
+
+        let expectedInputGradient = Tensor<Float>([
+            [-1.0127544e-02, -1.0807812e-03, -7.6115131e-04,  1.5857220e-03,  1.0383606e-02],
+            [ 2.0323221e-03,  6.2976527e-04, -2.1077941e-04, -2.1265696e-04, -2.2384699e-03],
+            [-1.3483668e-03,  3.7030075e-03,  1.8500184e-03,  9.6232636e-04, -5.1673558e-03],
+            [ 1.8438101e-03,  8.9146197e-05, -3.6990643e-03,  6.1964989e-04,  1.1463165e-03],
+            [ 1.2142579e-01,  1.7060755e-03, -6.5005139e-02, -9.3897656e-02,  3.5770576e-02]])
+        let expectedScaleGradient = Tensor<Float>([9.977925, 9.992161, 9.986738, 9.990202, 9.886292])
+        let expectedOffsetGradient = Tensor<Float>([0.0, 0.0, 0.0, 0.0, 0.0])
+        assertEqual(expectedTrainingValue, trainingValue, accuracy: 1e-5)
+        assertEqual(expectedInputGradient, grad.0, accuracy: 1e-5)
+        assertEqual(expectedScaleGradient, grad.1.scale, accuracy: 1e-5)
+        assertEqual(expectedOffsetGradient, grad.1.offset, accuracy: 1e-5)
+    }
+
+    func testLayerNorm() {
+        let x = Tensor<Float>([
+            [  -1.0474433,  -0.11914538,  -0.08634827,   0.15446888,    1.0572497],
+            [   1.5165012,    0.3753972,  -0.30856386,   -0.3100725,   -1.9584457],
+            [ 0.006384419,    1.4424847,   0.91568077,   0.66328526,   -1.0794537],
+            [    1.056803,   0.14263044,   -1.8308276,    0.4189805,    0.6933893],
+            [  0.30175626,  -0.16121633,   -0.4191958,  -0.53092813, -0.029484272]])
+        let lnLayer = LayerNorm<Float>(featureCount: 5, axis: 1)
+        let value = lnLayer(x)
+        let grad = gradient(at: x, lnLayer) { $1($0).squared().sum() }
+        // Uses the same values as testBatchNorm() above because LayerNorm with features on axis 1
+        // is equivalent to BatchNorm with features on axis 0.
+        let expectedValue = Tensor<Float>([
+            [-1.5439795,  -0.16477099, -0.11604305,  0.24174842,  1.5830451],
+            [ 1.4639764,   0.45368853, -0.15186328, -0.15319899, -1.6126028],
+            [-0.44139984,  1.2124169,   0.60574806,  0.3150888,  -1.6918538],
+            [ 0.9507547,   0.04595902, -1.9072568,   0.31947452,  0.5910686],
+            [ 1.5834246,   0.02224666, -0.8476793,  -1.2244489,   0.46645695]])
+
+        let expectedInputGradient = Tensor<Float>([
+            [-1.0127544e-02, -1.0807812e-03, -7.6115131e-04,  1.5857220e-03,  1.0383606e-02],
+            [ 2.0323221e-03,  6.2976527e-04, -2.1077941e-04, -2.1265696e-04, -2.2384699e-03],
+            [-1.3483668e-03,  3.7030075e-03,  1.8500184e-03,  9.6232636e-04, -5.1673558e-03],
+            [ 1.8438101e-03,  8.9146197e-05, -3.6990643e-03,  6.1964989e-04,  1.1463165e-03],
+            [ 1.2142579e-01,  1.7060755e-03, -6.5005139e-02, -9.3897656e-02,  3.5770576e-02]])
+        let expectedScaleGradient = Tensor<Float>([9.977925, 9.992161, 9.986738, 9.990202, 9.886292])
+        let expectedOffsetGradient = Tensor<Float>([0.0, 0.0, 0.0, 0.0, 0.0])
+        assertEqual(expectedValue, value, accuracy: 1e-5)
+        assertEqual(expectedInputGradient, grad.0, accuracy: 1e-5)
+        assertEqual(expectedScaleGradient, grad.1.scale, accuracy: 1e-5)
+        assertEqual(expectedOffsetGradient, grad.1.offset, accuracy: 1e-5)
+    }

     static var allTests = [
         ("testSequential", testSequential),
@@ -443,6 +527,8 @@ final class LayerTests: XCTestCase {
         ("testSimpleRNNCell", testSimpleRNNCell),
         ("testDense", testDense),
         ("testRNN", testRNN),
-        ("testFunction", testFunction)
+        ("testFunction", testFunction),
+        ("testBatchNorm", testBatchNorm),
+        ("testLayerNorm", testLayerNorm)
     ]
 }

0 commit comments
