@@ -71,16 +71,24 @@ public struct BatchNorm<Scalar: TensorFlowFloatingPoint>: Layer {
71
71
/// Returns the output obtained from applying the layer to the given input.
///
/// - Parameter input: The input to the layer.
/// - Returns: The output.
@differentiable
public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
    // Normalize a possibly-negative `axis` (e.g. -1 for "last") into [0, rank).
    let positiveAxis = (input.rank + axis) % input.rank
    var offset = self.offset
    var scale = self.scale
    // When the feature axis is not the last axis, reshape `offset` and `scale`
    // to [1, ..., featureCount, ..., 1] (featureCount at `positiveAxis`) so the
    // element-wise arithmetic below broadcasts along the feature axis. When the
    // feature axis is last, trailing-dimension broadcasting already applies.
    if positiveAxis != input.rank - 1 {
        var broadcastShape = TensorShape([Int](repeating: 1, count: input.rank))
        broadcastShape[positiveAxis] = input.shape[positiveAxis]
        offset = offset.reshaped(to: broadcastShape)
        scale = scale.reshaped(to: broadcastShape)
    }
    switch Context.local.learningPhase {
    case .training:
        // Compute batch statistics over every axis except the feature axis.
        var normalizedAxes = Array(0..<input.rank)
        normalizedAxes.remove(at: positiveAxis)
        let moments = input.moments(alongAxes: normalizedAxes)
        // Fold the batch statistics into the running statistics with an
        // exponential moving average controlled by `momentum`.
        runningMean.value += (moments.mean - runningMean.value) * (1 - momentum)
        runningVariance.value += (moments.variance - runningVariance.value) * (1 - momentum)
        let inv = rsqrt(moments.variance + epsilon) * scale
        return (input - moments.mean) * inv + offset
    case .inference:
        // At inference time, normalize using the accumulated running statistics.
        let inv = rsqrt(runningVariance.value + epsilon) * scale
        return (input - runningMean.value) * inv + offset
    }
}
@@ -100,13 +108,14 @@ public struct BatchNorm<Scalar: TensorFlowFloatingPoint>: Layer {
100
108
momentum: Tensor < Scalar > = Tensor ( 0.99 ) ,
101
109
epsilon: Tensor < Scalar > = Tensor ( 0.001 )
102
110
) {
103
- self . axis = axis
104
- self . momentum = momentum
105
- self . scale = Tensor < Scalar > ( ones: [ featureCount] )
106
- self . offset = Tensor < Scalar > ( zeros: [ featureCount] )
107
- self . epsilon = epsilon
108
- self . runningMean = Parameter ( Tensor ( 0 ) )
109
- self . runningVariance = Parameter ( Tensor ( 1 ) )
111
+ self . init (
112
+ axis: axis,
113
+ momentum: momentum,
114
+ offset: Tensor ( zeros: [ featureCount] ) ,
115
+ scale: Tensor ( ones: [ featureCount] ) ,
116
+ epsilon: epsilon,
117
+ runningMean: Tensor ( 0 ) ,
118
+ runningVariance: Tensor ( 1 ) )
110
119
}
111
120
}
112
121
@@ -152,8 +161,7 @@ public struct LayerNorm<Scalar: TensorFlowFloatingPoint>: Layer {
152
161
offset: Tensor ( zeros: [ featureCount] ) ,
153
162
scale: Tensor ( ones: [ featureCount] ) ,
154
163
axis: axis,
155
- epsilon: epsilon
156
- )
164
+ epsilon: epsilon)
157
165
}
158
166
159
167
/// Returns the output obtained from applying the layer to the given input.
///
/// - Parameter input: The input to the layer.
/// - Returns: The output.
@differentiable
public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
    // Normalize a possibly-negative `axis` (e.g. -1 for "last") into [0, rank).
    let positiveAxis = (input.rank + axis) % input.rank
    // Reshape `offset` and `scale` to the input's shape with the normalization
    // axis collapsed to 1, matching the shape of the moments computed below so
    // the arithmetic broadcasts across the normalization axis.
    // NOTE(review): `reshaped(to:)` requires the element counts to match, i.e.
    // `offset`/`scale` must hold one element per non-axis position — confirm
    // against how this layer's parameters are initialized.
    var broadcastShape = input.shape
    broadcastShape[positiveAxis] = 1
    let offset = self.offset.reshaped(to: broadcastShape)
    let scale = self.scale.reshaped(to: broadcastShape)
    // Mean and variance along the normalization axis; their shapes equal
    // `broadcastShape`.
    let moments = input.moments(alongAxes: axis)
    let inv = rsqrt(moments.variance + epsilon) * scale
    return (input - moments.mean) * inv + offset
}
169
182
}
0 commit comments