@@ -155,8 +155,9 @@ public struct BatchNorm<Scalar>: Layer
     @differentiable(wrt: (self, input))
     private func applyTraining(to input: Tensor<Scalar>) -> Tensor<Scalar> {
-        let mean = input.mean(alongAxes: axis)
-        let variance = input.variance(alongAxes: axis)
+        let positiveAxis = (input.rank + axis) % input.rank
+        let mean = input.mean(alongAxes: [0, positiveAxis])
+        let variance = input.variance(alongAxes: [0, positiveAxis])
         runningMean.value += (mean - runningMean.value) * (1 - momentum)
         runningVariance.value += (
             variance - runningVariance.value) * (1 - momentum)
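
Note on the statistics changes above, with a minimal illustrative sketch (plain Swift, made-up values; only the `positiveAxis` expression comes from the diff): `(input.rank + axis) % input.rank` folds a negative axis such as the new `-1` default (the last, channels-last feature axis) into a non-negative index, and the running-statistics updates are an incremental form of the usual exponential moving average.

    // Folding a negative axis into a non-negative one, as in applyTraining(to:).
    let rank = 4                             // e.g. a rank-4 NHWC image batch
    let axis = -1                            // new default: the last axis
    let positiveAxis = (rank + axis) % rank  // (4 - 1) % 4 == 3

    // `running += (batch - running) * (1 - momentum)` is algebraically the same
    // as the textbook EMA `running = momentum * running + (1 - momentum) * batch`.
    var runningMean: Float = 0.0
    let batchMean: Float = 2.0
    let momentum: Float = 0.99
    runningMean += (batchMean - runningMean) * (1 - momentum)
    print(positiveAxis, runningMean)         // 3 0.02
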
@@ -170,41 +171,33 @@ public struct BatchNorm<Scalar>: Layer
         return (input - runningMean.value) * inv + offset
     }
 
-    // TODO fix crasher in the below to enable behavior that differs between
-    // training and inference
-    //
-    // @differentiable(wrt: (self, input), vjp: _vjpApplied(to:))
-    // public func applied(to input: Tensor<Scalar>) -> Tensor<Scalar> {
-    //     if learningPhaseIndicator.training {
-    //         return applyTraining(to: input)
-    //     } else {
-    //         return applyInference(to: input)
-    //     }
-    // }
-    //
-    // public func _vjpApplied(to input: Tensor<Scalar>) ->
-    //     (Tensor<Scalar>, (Tensor<Scalar>) ->
-    //         (BatchNorm<Scalar>.CotangentVector, Tensor<Scalar>)) {
-    //     if learningPhaseIndicator.training {
-    //         return self.valueWithPullback(at: input) {
-    //             $0.applyTraining(to: $1)
-    //         }
-    //     } else {
-    //         return self.valueWithPullback(at: input) {
-    //             $0.applyInference(to: $1)
-    //         }
-    //     }
-    // }
-    //
-    // Work around for now by always using training mode
-    @differentiable(wrt: (self, input))
+    @differentiable(wrt: (self, input), vjp: _vjpApplied(to:))
     public func applied(to input: Tensor<Scalar>) -> Tensor<Scalar> {
-        return applyTraining(to: input)
+        if learningPhaseIndicator.training {
+            return applyTraining(to: input)
+        } else {
+            return applyInference(to: input)
+        }
+    }
+
+    @usableFromInline
+    func _vjpApplied(to input: Tensor<Scalar>) ->
+        (Tensor<Scalar>, (Tensor<Scalar>) ->
+            (BatchNorm<Scalar>.CotangentVector, Tensor<Scalar>)) {
+        if learningPhaseIndicator.training {
+            return self.valueWithPullback(at: input) {
+                $0.applyTraining(to: $1)
+            }
+        } else {
+            return self.valueWithPullback(at: input) {
+                $0.applyInference(to: $1)
+            }
+        }
     }
 
     public init(featureCount: Int,
                 learningPhaseIndicator: LearningPhaseIndicator,
-                axis: Int = 0,
+                axis: Int = -1,
                 momentum: Tensor<Scalar> = Tensor(0.99),
                 epsilon: Tensor<Scalar> = Tensor(0.001)) {
         self.axis = Int32(axis)
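
The enabling change in this hunk is the custom VJP: `applied(to:)` now branches on `learningPhaseIndicator.training`, which the deleted TODO notes used to crash the compiler, so `@differentiable(wrt: (self, input), vjp: _vjpApplied(to:))` routes gradient computation through the hand-written `_vjpApplied(to:)`, which calls `valueWithPullback(at:)` on whichever branch is active (`@usableFromInline` presumably keeps the internal VJP referenceable from the public differentiable `applied(to:)`). A minimal sketch of the same pattern on a free function, written against the `@differentiable(vjp:)` spelling of this era's Swift for TensorFlow toolchains (that argument has since been removed from the language):

    // The compiler does not derive a pullback for `square`; differentiation
    // uses the hand-written `_vjpSquare` below instead.
    @differentiable(vjp: _vjpSquare)
    func square(_ x: Float) -> Float {
        return x * x
    }

    // A VJP returns the original value together with a pullback that maps an
    // upstream cotangent `v` to the gradient with respect to the input.
    func _vjpSquare(_ x: Float) -> (Float, (Float) -> Float) {
        return (x * x, { v in 2 * x * v })
    }

    // Usage: gradient(at: 3, in: square) evaluates to 6.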