This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit f218a34

Normalization layers fix (fixes #384 and #426). (#428)
1 parent daff615

File tree

Sources/TensorFlow/Layers/Normalization.swift
Tests/TensorFlowTests/LayerTests.swift

2 files changed: +38 -21 lines changed

Sources/TensorFlow/Layers/Normalization.swift

Lines changed: 27 additions & 14 deletions
@@ -71,16 +71,24 @@ public struct BatchNorm<Scalar: TensorFlowFloatingPoint>: Layer {
     /// - Returns: The output.
     @differentiable
     public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
+        let positiveAxis = (input.rank + axis) % input.rank
+        var offset = self.offset
+        var scale = self.scale
+        if positiveAxis != input.rank - 1 {
+            var broadcastShape = TensorShape([Int](repeating: 1, count: input.rank))
+            broadcastShape[positiveAxis] = input.shape[positiveAxis]
+            offset = offset.reshaped(to: broadcastShape)
+            scale = scale.reshaped(to: broadcastShape)
+        }
         switch Context.local.learningPhase {
         case .training:
-            let positiveAxis = (input.rank + axis) % input.rank
             var normalizedAxes = Array(0..<input.rank)
             normalizedAxes.remove(at: positiveAxis)
             let moments = input.moments(alongAxes: normalizedAxes)
             runningMean.value += (moments.mean - runningMean.value) * (1 - momentum)
             runningVariance.value += (moments.variance - runningVariance.value) * (1 - momentum)
-            let inv = rsqrt(moments.variance + epsilon) * scale.reshaped(to: moments.variance.shape)
-            return (input - moments.mean) * inv + offset.reshaped(to: moments.mean.shape)
+            let inv = rsqrt(moments.variance + epsilon) * scale
+            return (input - moments.mean) * inv + offset
         case .inference:
             let inv = rsqrt(runningVariance.value + epsilon) * scale
             return (input - runningMean.value) * inv + offset
@@ -100,13 +108,14 @@ public struct BatchNorm<Scalar: TensorFlowFloatingPoint>: Layer {
         momentum: Tensor<Scalar> = Tensor(0.99),
         epsilon: Tensor<Scalar> = Tensor(0.001)
     ) {
-        self.axis = axis
-        self.momentum = momentum
-        self.scale = Tensor<Scalar>(ones: [featureCount])
-        self.offset = Tensor<Scalar>(zeros: [featureCount])
-        self.epsilon = epsilon
-        self.runningMean = Parameter(Tensor(0))
-        self.runningVariance = Parameter(Tensor(1))
+        self.init(
+            axis: axis,
+            momentum: momentum,
+            offset: Tensor(zeros: [featureCount]),
+            scale: Tensor(ones: [featureCount]),
+            epsilon: epsilon,
+            runningMean: Tensor(0),
+            runningVariance: Tensor(1))
     }
 }

@@ -152,8 +161,7 @@ public struct LayerNorm<Scalar: TensorFlowFloatingPoint>: Layer {
             offset: Tensor(zeros: [featureCount]),
             scale: Tensor(ones: [featureCount]),
             axis: axis,
-            epsilon: epsilon
-        )
+            epsilon: epsilon)
     }
 
     /// Returns the output obtained from applying the layer to the given input.
@@ -162,8 +170,13 @@ public struct LayerNorm<Scalar: TensorFlowFloatingPoint>: Layer {
     /// - Returns: The output.
     @differentiable
     public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
+        let positiveAxis = (input.rank + axis) % input.rank
+        var broadcastShape = input.shape
+        broadcastShape[positiveAxis] = 1
+        let offset = self.offset.reshaped(to: broadcastShape)
+        let scale = self.scale.reshaped(to: broadcastShape)
         let moments = input.moments(alongAxes: axis)
-        let inv = rsqrt(moments.variance + epsilon) * scale.reshaped(to: moments.variance.shape)
-        return (input - moments.mean) * inv + offset.reshaped(to: moments.mean.shape)
+        let inv = rsqrt(moments.variance + epsilon) * scale
+        return (input - moments.mean) * inv + offset
     }
 }

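For context, a minimal usage sketch of what the `BatchNorm` change above enables: with the broadcast reshape in place, `offset` and `scale` no longer have to sit on the last axis, so a channels-first (NCHW) input with `axis: 1` works. This sketch is not part of the commit; the shapes and values are arbitrary, chosen only for illustration.

```swift
import TensorFlow

// Hypothetical example: BatchNorm over a channels-first input. With the fix,
// `offset` and `scale` (shape [featureCount]) are reshaped to [1, 3, 1, 1]
// so they broadcast against the [batch, channels, height, width] input.
var layer = BatchNorm<Float>(featureCount: 3, axis: 1)
let input = Tensor<Float>(randomNormal: [2, 3, 4, 4])
Context.local.learningPhase = .training
let output = layer(input)
print(output.shape)  // [2, 3, 4, 4]
```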
Tests/TensorFlowTests/LayerTests.swift

Lines changed: 11 additions & 7 deletions
@@ -474,7 +474,7 @@ final class LayerTests: XCTestCase {
         let grad = gradient(at: x, bnLayer) { $1($0).squared().sum() }
         // The expected values and gradients were computed using the following Python code:
         // ```
-        // x = tf.constant(
+        // x = tf.constant(
         //     [[ -1.0474433, -0.11914538, -0.08634827, 0.15446888, 1.0572497],
         //      [ 1.5165012, 0.3753972, -0.30856386, -0.3100725, -1.9584457],
         //      [ 0.006384419, 1.4424847, 0.91568077, 0.66328526, -1.0794537],
@@ -509,8 +509,10 @@ final class LayerTests: XCTestCase {
             [ 1.2142579e-01, 1.7060755e-03, -6.5005139e-02, -9.3897656e-02, 3.5770576e-02]],
             accuracy: 1e-5)
         assertEqual(grad.1.offset, [0.0, 0.0, 0.0, 0.0, 0.0], accuracy: 1e-5)
-        assertEqual(grad.1.scale, [9.977925, 9.992161, 9.986738, 9.990202, 9.886292],
-                    accuracy: 1e-5)
+        assertEqual(
+            grad.1.scale,
+            [9.977925, 9.992161, 9.986738, 9.990202, 9.886292],
+            accuracy: 1e-5)
     }
 }

@@ -525,8 +527,8 @@ final class LayerTests: XCTestCase {
         let value = lnLayer(x)
         let grad = gradient(at: x, lnLayer) { $1($0).squared().sum() }
 
-        // Uses the same values as `testBatchNorm()` above because `LayerNorm` with features on axis
-        // `1` is equivalent to `BatchNorm` with features on axis `0`.
+        // Uses the same values as `testBatchNorm()` above because `LayerNorm` with features on
+        // axis `1` is equivalent to `BatchNorm` with features on axis `0`.
         assertEqual(
             value,
             [[-1.5439795 , -0.16477099, -0.11604305, 0.24174842, 1.5830451 ],
@@ -543,9 +545,11 @@ final class LayerTests: XCTestCase {
             [ 1.8438101e-03, 8.9146197e-05, -3.6990643e-03, 6.1964989e-04, 1.1463165e-03],
             [ 1.2142579e-01, 1.7060755e-03, -6.5005139e-02, -9.3897656e-02, 3.5770576e-02]],
             accuracy: 1e-5)
-        assertEqual(grad.1.scale, [9.977925, 9.992161, 9.986738, 9.990202, 9.886292],
-                    accuracy: 1e-5)
         assertEqual(grad.1.offset, [0.0, 0.0, 0.0, 0.0, 0.0], accuracy: 1e-5)
+        assertEqual(
+            grad.1.scale,
+            [9.977925, 9.992161, 9.986738, 9.990202, 9.886292],
+            accuracy: 1e-5)
     }
 
     static var allTests = [

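The comment in `testLayerNorm` above leans on an equivalence that is easy to check directly. A minimal sketch, not part of the commit: it assumes a square `[5, 5]` input with arbitrary values (not the fixture used by the tests). In the training phase both layers take moments along axis 1 with the same epsilon, so their outputs agree.

```swift
import TensorFlow

// Hypothetical check: for a [5, 5] input, `LayerNorm` with features on
// axis 1 and `BatchNorm` with features on axis 0 both normalize along
// axis 1 during training, so the two outputs match.
let x = Tensor<Float>(randomNormal: [5, 5])
let lnLayer = LayerNorm<Float>(featureCount: 5, axis: 1, epsilon: Tensor(0.001))
let bnLayer = BatchNorm<Float>(featureCount: 5, axis: 0, epsilon: Tensor(0.001))
Context.local.learningPhase = .training
let difference = abs(lnLayer(x) - bnLayer(x)).max()
print(difference)  // ~0, up to floating-point accuracy
```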