@@ -87,3 +87,49 @@ public extension Conv2D where Scalar : BinaryFloatingPoint,
87
87
self . init ( filter: Tensor ( randomNormal: filterShape) , strides: strides, padding: padding)
88
88
}
89
89
}
90
+
91
+ @_fixed_layout
92
+ public struct BatchNorm < Scalar> : Layer
93
+ where Scalar : BinaryFloatingPoint & Differentiable & TensorFlowScalar {
94
+ /// The batch dimension.
95
+ @noDerivative public let axis : Int32
96
+
97
+ /// The momentum for the running mean and running variance.
98
+ @noDerivative public let momentum : Tensor < Scalar >
99
+
100
+ /// The offset value, also known as beta.
101
+ public var offset : Tensor < Scalar >
102
+
103
+ /// The scale value, also known as gamma.
104
+ public var scale : Tensor < Scalar >
105
+
106
+ /// The variance epsilon value.
107
+ @noDerivative public let epsilon : Tensor < Scalar >
108
+
109
+ /// The running mean.
110
+ @noDerivative public var runningMean : Tensor < Scalar >
111
+
112
+ /// The running variance.
113
+ @noDerivative public var runningVariance : Tensor < Scalar >
114
+
115
+ @differentiable ( wrt: ( self , input) )
116
+ public func applied( to input: Tensor < Scalar > ) -> Tensor < Scalar > {
117
+ return input. batchNormalized ( alongAxis: axis, offset: offset,
118
+ scale: scale, epsilon: epsilon)
119
+ }
120
+
121
+ public init ( axis: Int32 ,
122
+ momentum: Tensor < Scalar > = Tensor ( 0.99 ) ,
123
+ offset: Tensor < Scalar > = Tensor ( 0 ) ,
124
+ scale: Tensor < Scalar > = Tensor ( 0 ) ,
125
+ epsilon: Tensor < Scalar > = Tensor ( 0.001 ) ) {
126
+ self . axis = axis
127
+ self . momentum = momentum
128
+ self . offset = offset
129
+ self . scale = scale
130
+ self . epsilon = epsilon
131
+ /// Initialize running mean and variance to zero.
132
+ self . runningMean = Tensor ( 0 )
133
+ self . runningVariance = Tensor ( 1 )
134
+ }
135
+ }
0 commit comments