Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit ee72ba1

Browse files
authored
[NFC] Gardening. (#552)
Clean up `_vjpProduct` and Python TensorFlow batch normalization reference code.
1 parent 68516fb commit ee72ba1

File tree

2 files changed

+9
-9
lines changed

2 files changed

+9
-9
lines changed

Sources/TensorFlow/Operators/Math.swift

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -2239,7 +2239,7 @@ internal extension Tensor where Scalar: TensorFlowFloatingPoint {
22392239
let result = product(squeezingAxes: axes)
22402240
return (result, { v in
22412241
// Reshape reduction indices for the case where the parameter is a scalar.
2242-
var reductionIndices = axes.reshaped(to: TensorShape(-1))
2242+
var reductionIndices = axes.flattened()
22432243
// Normalize any negative reduction indices to positive values.
22442244
reductionIndices = (reductionIndices + Int32(self.rank)) % Int32(self.rank)
22452245

@@ -2248,8 +2248,7 @@ internal extension Tensor where Scalar: TensorFlowFloatingPoint {
22482248
for axis in reductionIndices.scalars {
22492249
outputShape[Int(axis)] = 1
22502250
}
2251-
let vReshaped = v.reshaped(to: outputShape)
2252-
let vBroadcasted = vReshaped.broadcasted(to: self.shape)
2251+
let vBroadcasted = v.reshaped(to: outputShape).broadcasted(to: self.shape)
22532252

22542253
// Pack all reduced dimensions into a single one, so we can perform the
22552254
// `cumulativeProduct` operations.

Tests/TensorFlowTests/TensorAutoDiffTests.swift

Lines changed: 7 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -509,12 +509,13 @@ final class TensorAutoDiffTests: XCTestCase {
509509
let computedGradient = gradient(at: x) { $0.batchNormalized(alongAxis: 1).squared().sum() }
510510
// The expected value of the gradient was computed using the following Python code:
511511
// ```
512-
// with tf.GradientTape() as t:
513-
// t.watch(x)
514-
// mean, var = tf.nn.moments(x, axes=1, keepdims=True)
515-
// y = tf.reduce_sum(tf.square(tf.nn.batch_normalization(
516-
// x, mean, var, offset=0, scale=1, variance_epsilon=0.001)))
517-
// print(t.gradient(y, x))
512+
// import tensorflow as tf
513+
// with tf.GradientTape() as t:
514+
// t.watch(x)
515+
// mean, var = tf.nn.moments(x, axes=1, keepdims=True)
516+
// y = tf.reduce_sum(tf.square(tf.nn.batch_normalization(
517+
// x, mean, var, offset=0, scale=1, variance_epsilon=0.001)))
518+
// print(t.gradient(y, x))
518519
// ```
519520
let expectedGradient = Tensor<Float>([
520521
[-1.0127544e-02, -1.0807812e-03, -7.6115131e-04, 1.5857220e-03, 1.0383606e-02],

0 commit comments

Comments (0)