This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 93d8ea5

BatchNorm changes: fix axis handling and drop the workaround for the AD crasher
1 parent: c8c0fa0


Sources/DeepLearning/Layer.swift

Lines changed: 25 additions & 32 deletions
@@ -155,8 +155,9 @@ public struct BatchNorm<Scalar>: Layer
 
     @differentiable(wrt: (self, input))
     private func applyTraining(to input: Tensor<Scalar>) -> Tensor<Scalar> {
-        let mean = input.mean(alongAxes: axis)
-        let variance = input.variance(alongAxes: axis)
+        let positiveAxis = (input.rank + axis) % input.rank
+        let mean = input.mean(alongAxes: [0, positiveAxis])
+        let variance = input.variance(alongAxes: [0, positiveAxis])
         runningMean.value += (mean - runningMean.value) * (1 - momentum)
         runningVariance.value += (
             variance - runningVariance.value) * (1 - momentum)
@@ -170,41 +171,33 @@ public struct BatchNorm<Scalar>: Layer
         return (input - runningMean.value) * inv + offset
     }
 
-    // TODO fix crasher in the below to enable behavior that differs between
-    // training and inference
-    //
-    // @differentiable(wrt: (self, input), vjp: _vjpApplied(to:))
-    // public func applied(to input: Tensor<Scalar>) -> Tensor<Scalar> {
-    //     if learningPhaseIndicator.training {
-    //         return applyTraining(to: input)
-    //     } else {
-    //         return applyInference(to: input)
-    //     }
-    // }
-    //
-    // public func _vjpApplied(to input: Tensor<Scalar>) ->
-    //     (Tensor<Scalar>, (Tensor<Scalar>) ->
-    //         (BatchNorm<Scalar>.CotangentVector, Tensor<Scalar>)) {
-    //     if learningPhaseIndicator.training {
-    //         return self.valueWithPullback(at: input) {
-    //             $0.applyTraining(to: $1)
-    //         }
-    //     } else {
-    //         return self.valueWithPullback(at: input) {
-    //             $0.applyInference(to: $1)
-    //         }
-    //     }
-    // }
-    //
-    // Work around for now by always using training mode
-    @differentiable(wrt: (self, input))
+    @differentiable(wrt: (self, input), vjp: _vjpApplied(to:))
     public func applied(to input: Tensor<Scalar>) -> Tensor<Scalar> {
-        return applyTraining(to: input)
+        if learningPhaseIndicator.training {
+            return applyTraining(to: input)
+        } else {
+            return applyInference(to: input)
+        }
+    }
+
+    @usableFromInline
+    func _vjpApplied(to input: Tensor<Scalar>) ->
+        (Tensor<Scalar>, (Tensor<Scalar>) ->
+            (BatchNorm<Scalar>.CotangentVector, Tensor<Scalar>)) {
+        if learningPhaseIndicator.training {
+            return self.valueWithPullback(at: input) {
+                $0.applyTraining(to: $1)
+            }
+        } else {
+            return self.valueWithPullback(at: input) {
+                $0.applyInference(to: $1)
+            }
+        }
     }
 
     public init(featureCount: Int,
                 learningPhaseIndicator: LearningPhaseIndicator,
-                axis: Int = 0,
+                axis: Int = -1,
                 momentum: Tensor<Scalar> = Tensor(0.99),
                 epsilon: Tensor<Scalar> = Tensor(0.001)) {
         self.axis = Int32(axis)
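
As context for the first hunk: with the new default axis of -1, applyTraining(to:) first maps a negative axis to a positive index before taking statistics. A minimal sketch of that arithmetic, assuming a rank-4 (e.g. [batch, height, width, channels]) input; the names below are illustrative and not part of the commit:

    let rank = 4                               // e.g. an NHWC input tensor
    let axis = -1                              // new default channel axis in this commit
    let positiveAxis = (rank + axis) % rank    // -1 resolves to 3, the last axis
    // mean and variance are then taken alongAxes [0, positiveAxis],
    // i.e. over the batch axis and the resolved feature axis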
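The second hunk restores the training/inference split behind a custom VJP, so applied(to:) now dispatches on the learning phase instead of always running the training path. A hedged usage sketch, assuming LearningPhaseIndicator exposes a settable training flag and that Tensor(randomNormal:) is available in this version of the library; shapes and names are illustrative only:

    let indicator = LearningPhaseIndicator()       // assumes a settable `training` flag
    let bn = BatchNorm<Float>(featureCount: 16,
                              learningPhaseIndicator: indicator,
                              axis: -1)             // new default, resolves to the last axis
    let x = Tensor<Float>(randomNormal: [8, 32, 32, 16])  // assumed NHWC batch

    indicator.training = true
    let trainOut = bn.applied(to: x)   // batch statistics; running mean/variance are updated

    indicator.training = false
    let evalOut = bn.applied(to: x)    // running statistics via applyInference(to:)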
