This repository was archived by the owner on Jul 1, 2023. It is now read-only.

fix BatchNorm inference scale and offset shapes, add test #429

Closed
wants to merge 1 commit into from
18 changes: 14 additions & 4 deletions Sources/TensorFlow/Layers/Normalization.swift
@@ -20,6 +20,8 @@
 ///
 /// Reference: [Batch Normalization: Accelerating Deep Network Training by Reducing Internal
 /// Covariate Shift](https://arxiv.org/abs/1502.03167).
+
+
 @frozen
 public struct BatchNorm<Scalar: TensorFlowFloatingPoint>: Layer {
     /// The feature dimension.
@@ -71,6 +73,14 @@ public struct BatchNorm<Scalar: TensorFlowFloatingPoint>: Layer {
     /// - Returns: The output.
     @differentiable
     public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
+        var reshapedScale = scale
+        var reshapedOffset = offset
+        if axis != input.rank - 1 {
+            var offsetAndScaleShape = Array(repeating: 1, count: input.rank)
+            offsetAndScaleShape[axis] = scale.shape[0]
+            reshapedScale = scale.reshaped(to: TensorShape(offsetAndScaleShape))
+            reshapedOffset = offset.reshaped(to: TensorShape(offsetAndScaleShape))
+        }
         switch Context.local.learningPhase {
         case .training:
             let positiveAxis = (input.rank + axis) % input.rank
@@ -79,11 +89,11 @@ public struct BatchNorm<Scalar: TensorFlowFloatingPoint>: Layer {
             let moments = input.moments(alongAxes: normalizedAxes)
             runningMean.value += (moments.mean - runningMean.value) * (1 - momentum)
             runningVariance.value += (moments.variance - runningVariance.value) * (1 - momentum)
-            let inv = rsqrt(moments.variance + epsilon) * scale.reshaped(to: moments.variance.shape)
-            return (input - moments.mean) * inv + offset.reshaped(to: moments.mean.shape)
+            let inv = rsqrt(moments.variance + epsilon) * reshapedScale
+            return (input - moments.mean) * inv + reshapedOffset
         case .inference:
-            let inv = rsqrt(runningVariance.value + epsilon) * scale
-            return (input - runningMean.value) * inv + offset
+            let inv = rsqrt(runningVariance.value + epsilon) * reshapedScale
+            return (input - runningMean.value) * inv + reshapedOffset
         }
     }

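The core of the fix is the broadcast-shape computation added above: a rank-1 parameter vector of shape [featureCount] broadcasts only against the last axis by default, so for any other axis it must first be reshaped to a shape that is 1 everywhere except at that axis. A minimal standalone sketch of that logic (the broadcastShape helper is illustrative, not part of this PR):

import TensorFlow

// Build a shape that is 1 on every axis except `axis`, which carries the
// feature count, so a rank-1 parameter broadcasts against a rank-N input.
func broadcastShape(featureCount: Int, axis: Int, rank: Int) -> TensorShape {
    var shape = Array(repeating: 1, count: rank)
    shape[axis] = featureCount
    return TensorShape(shape)
}

let scale = Tensor<Float>(ones: [5])       // per-feature scale, shape [5]
let input = Tensor<Float>(zeros: [5, 3])   // features laid out along axis 0
let reshaped = scale.reshaped(to: broadcastShape(featureCount: 5, axis: 0, rank: input.rank))
// `reshaped` has shape [5, 1] and broadcasts row-wise against [5, 3];
// the unreshaped [5] would have broadcast against the last axis instead.
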
18 changes: 18 additions & 0 deletions Tests/TensorFlowTests/LayerTests.swift
@@ -496,6 +496,24 @@ final class LayerTests: XCTestCase {
         }
     }

+    func testBatchNormInference() {
+        let x = Tensor<Float>(rangeFrom: 0, to: 25, stride: 1).reshaped(to: [5, 5])
+        let epsilon = Tensor<Float>(0.001)
+        let bnLayer = BatchNorm<Float>(featureCount: 5, axis: 0, epsilon: epsilon)
+        // Test that inference before any training differs only by the epsilon value.
+        assertEqual(bnLayer.inferring(from: x), x / TensorFlow.sqrt(1 + epsilon), accuracy: 1e-6)
+        // Test inference after a single training step.
+        Context.local.learningPhase = .training
+        let y = bnLayer(x)
+        assertEqual(bnLayer.inferring(from: x),
+                    [[-0.01989088, 0.974654, 1.969199, 2.963744, 3.958289],
+                     [ 4.9031067, 5.8976517, 6.8921967, 7.8867416, 8.881287],
+                     [ 9.826104, 10.820649, 11.815194, 12.809739, 13.804284],
+                     [ 14.749101, 15.743646, 16.738192, 17.732737, 18.727282],
+                     [ 19.6721, 20.666645, 21.66119, 22.655735, 23.65028]],
+                    accuracy: 1e-6)
+    }
+
     func testLayerNorm() {
         let x = Tensor<Float>([
             [ -1.0474433, -0.11914538, -0.08634827, 0.15446888, 1.0572497],