Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 6ca3813

Browse files
committed
Add BatchNorm layer.
Compilation currently doesn't work.
1 parent 731ce40 commit 6ca3813

File tree

1 file changed

+46
-0
lines changed

1 file changed

+46
-0
lines changed

Sources/DeepLearning/Layer.swift

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,3 +87,49 @@ public extension Conv2D where Scalar : BinaryFloatingPoint,
8787
self.init(filter: Tensor(randomNormal: filterShape), strides: strides, padding: padding)
8888
}
8989
}
90+
91+
/// A batch normalization layer.
///
/// Normalizes the input along `axis` using a learned offset (beta) and
/// scale (gamma), via `Tensor.batchNormalized(alongAxis:offset:scale:epsilon:)`.
@_fixed_layout
public struct BatchNorm<Scalar>: Layer
    where Scalar : BinaryFloatingPoint & Differentiable & TensorFlowScalar {
    /// The batch dimension.
    @noDerivative public let axis: Int32

    /// The momentum for the running mean and running variance.
    /// NOTE(review): stored but never read — `applied(to:)` does not update
    /// `runningMean`/`runningVariance`. Presumably training/inference-mode
    /// statistics tracking is still to come; confirm before relying on it.
    @noDerivative public let momentum: Tensor<Scalar>

    /// The offset value, also known as beta.
    public var offset: Tensor<Scalar>

    /// The scale value, also known as gamma.
    public var scale: Tensor<Scalar>

    /// The variance epsilon value, added for numerical stability.
    @noDerivative public let epsilon: Tensor<Scalar>

    /// The running mean.
    @noDerivative public var runningMean: Tensor<Scalar>

    /// The running variance.
    @noDerivative public var runningVariance: Tensor<Scalar>

    /// Applies batch normalization to `input` along `axis`, using the current
    /// `offset`, `scale`, and `epsilon`.
    @differentiable(wrt: (self, input))
    public func applied(to input: Tensor<Scalar>) -> Tensor<Scalar> {
        return input.batchNormalized(alongAxis: axis, offset: offset,
                                     scale: scale, epsilon: epsilon)
    }

    /// Creates a batch normalization layer.
    ///
    /// - Parameters:
    ///   - axis: The batch dimension to normalize along.
    ///   - momentum: The momentum for the running statistics.
    ///   - offset: The initial offset (beta). Defaults to zero.
    ///   - scale: The initial scale (gamma). Defaults to one so the layer
    ///     starts as an identity transform. (Fixed: a zero default would
    ///     zero out every normalized activation.)
    ///   - epsilon: Small value added to the variance for numerical stability.
    public init(axis: Int32,
                momentum: Tensor<Scalar> = Tensor(0.99),
                offset: Tensor<Scalar> = Tensor(0),
                scale: Tensor<Scalar> = Tensor(1),
                epsilon: Tensor<Scalar> = Tensor(0.001)) {
        self.axis = axis
        self.momentum = momentum
        self.offset = offset
        self.scale = scale
        self.epsilon = epsilon
        // Initialize the running mean to zero and the running variance to one
        // (the identity statistics for normalization).
        self.runningMean = Tensor(0)
        self.runningVariance = Tensor(1)
    }
}

0 commit comments

Comments
 (0)