
Commit c84d4c0

Fix derivatives for min reduction ops. (#590)
Fix VJPs for:
- `Tensor.min(alongAxes:)`
- `Tensor.min(squeezingAxes:)`

Previously, the VJPs returned the result of `max` as the original value. Now, the result of `min` is correctly returned.

Add tests, checking against Python TensorFlow.
1 parent eebb487 commit c84d4c0
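
In effect, differentiating a `min` reduction now yields `min`-consistent values and gradients. A minimal sketch of the fixed behavior (assuming a Swift for TensorFlow toolchain; the values mirror the tests added in this commit):

import TensorFlow

let x: Tensor<Float> = [[0, 1, 2], [3, 4, 5]]
// Before this fix, the `min(squeezingAxes:)` VJP returned the result of
// `max` as the original value, so the pullback was computed against `max`.
let (value, grad) = valueWithGradient(at: x) { $0.min(squeezingAxes: 0).sum() }
// With the fix: value == 3 (0 + 1 + 2), and the gradient flows to the minima:
// grad == [[1, 1, 1], [0, 0, 0]]
print(value, grad)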

File tree: 2 files changed (+211, -61 lines)

Sources/TensorFlow/Operators/Math.swift

Lines changed: 70 additions & 30 deletions
@@ -1396,6 +1396,8 @@ public func min<T>(_ lhs: Tensor<T>, _ rhs: T) -> Tensor<T> where T: Numeric & C
     min(lhs, Tensor(rhs))
 }
 
+// Note: adapted from `_MinOrMaxGrad`:
+// https://github.com/tensorflow/tensorflow/blob/r2.1/tensorflow/python/ops/math_grad.py#L223.
 @inlinable
 internal func _vjpMinMaxHelper<T: TensorFlowFloatingPoint>(
     _ x: Tensor<T>,
@@ -1554,7 +1556,7 @@ public extension Tensor where Scalar: Numeric & Comparable {
     @inlinable
     @differentiable(
         wrt: self,
-        vjp: _vjpMinOrMax(squeezingAxes:) where Scalar: TensorFlowFloatingPoint)
+        vjp: _vjpMax(squeezingAxes:) where Scalar: TensorFlowFloatingPoint)
     func max(squeezingAxes axes: Tensor<Int32>) -> Tensor {
         return _Raw.max(self, reductionIndices: axes, keepDims: false)
     }
@@ -1585,7 +1587,7 @@ public extension Tensor where Scalar: Numeric & Comparable {
     @inlinable
     @differentiable(
         wrt: self,
-        vjp: _vjpMinOrMax(squeezingAxes:) where Scalar: TensorFlowFloatingPoint)
+        vjp: _vjpMin(squeezingAxes:) where Scalar: TensorFlowFloatingPoint)
     func min(squeezingAxes axes: Tensor<Int32>) -> Tensor {
         _Raw.min(self, reductionIndices: axes, keepDims: false)
     }
@@ -1633,7 +1635,7 @@ public extension Tensor where Scalar: Numeric & Comparable {
     /// - Parameter axes: The dimensions to reduce.
     /// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
     @inlinable
-    @differentiable(wrt: self, vjp: _vjpMinOrMax(alongAxes:) where Scalar: TensorFlowFloatingPoint)
+    @differentiable(wrt: self, vjp: _vjpMin(alongAxes:) where Scalar: TensorFlowFloatingPoint)
     func min(alongAxes axes: Tensor<Int32>) -> Tensor {
         _Raw.min(self, reductionIndices: axes, keepDims: true)
     }
@@ -1665,7 +1667,7 @@ public extension Tensor where Scalar: Numeric & Comparable {
     /// - Parameter axes: The dimensions to reduce.
     /// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
     @inlinable
-    @differentiable(wrt: self, vjp: _vjpMinOrMax(alongAxes:) where Scalar: TensorFlowFloatingPoint)
+    @differentiable(wrt: self, vjp: _vjpMax(alongAxes:) where Scalar: TensorFlowFloatingPoint)
     func max(alongAxes axes: Tensor<Int32>) -> Tensor {
         _Raw.max(self, reductionIndices: axes, keepDims: true)
     }
@@ -1706,35 +1708,73 @@ public extension Tensor where Scalar: Numeric & Comparable {
 }
 
 internal extension Tensor where Scalar: TensorFlowFloatingPoint {
-    @inlinable
-    func _vjpMinOrMax(squeezingAxes axes: Tensor<Int32>) -> (Tensor, (Tensor) -> Tensor) {
-        let result = max(squeezingAxes: axes)
-        return (result, { v in
-            let yUnsqueezed = result.expandingShape(at: axes.scalars.map { Int($0) })
-            let gradientUnsqueezed = v.expandingShape(at: axes.scalars.map { Int($0) })
+    // Note: adapted from `_MinOrMaxGrad`:
+    // https://github.com/tensorflow/tensorflow/blob/r2.1/tensorflow/python/ops/math_grad.py#L223.
+    @inlinable
+    func _vjpMinMaxHelper(
+        squeezingAxes axes: Tensor<Int32>,
+        originalValue: Tensor,
+        seed: Tensor
+    ) -> Tensor {
+        let yUnsqueezed = originalValue.expandingShape(at: axes.scalars.map { Int($0) })
+        let gradientUnsqueezed = seed.expandingShape(at: axes.scalars.map { Int($0) })
 
-            // Compute the number of selected (maximum or minimum) elements in each reduction dimension.
-            // If there are multiple minimum or maximum elements then the gradient will be divided between
-            // them.
-            let indicators = Tensor(yUnsqueezed .== self)
-            let selectedCount = indicators.sum(alongAxes: axes)
+        // Compute the number of selected (maximum or minimum) elements in each reduction dimension.
+        // If there are multiple minimum or maximum elements then the gradient will be divided
+        // between them.
+        let indicators = Tensor(yUnsqueezed .== self)
+        let selectedCount = indicators.sum(alongAxes: axes)
 
-            return gradientUnsqueezed.broadcasted(toShape: self.shapeTensor) * indicators / selectedCount
-        })
-    }
+        return gradientUnsqueezed.broadcasted(toShape: self.shapeTensor) * indicators / selectedCount
+    }
 
-    @inlinable
-    func _vjpMinOrMax(alongAxes axes: Tensor<Int32>) -> (Tensor, (Tensor) -> Tensor) {
-        let result = max(alongAxes: axes)
-        return (result, { v in
-            // Compute the number of selected (maximum or minimum) elements in each reduction dimension.
-            // If there are multiple minimum or maximum elements then the gradient will be divided between
-            // them.
-            let indicators = Tensor(result .== self)
-            let selectedCount = indicators.sum(alongAxes: axes)
-            return v.broadcasted(toShape: self.shapeTensor) * indicators / selectedCount
-        })
-    }
+    @inlinable
+    func _vjpMax(squeezingAxes axes: Tensor<Int32>) -> (Tensor, (Tensor) -> Tensor) {
+        let result = max(squeezingAxes: axes)
+        return (result, { v in
+            self._vjpMinMaxHelper(squeezingAxes: axes, originalValue: result, seed: v)
+        })
+    }
+
+    @inlinable
+    func _vjpMin(squeezingAxes axes: Tensor<Int32>) -> (Tensor, (Tensor) -> Tensor) {
+        let result = min(squeezingAxes: axes)
+        return (result, { v in
+            self._vjpMinMaxHelper(squeezingAxes: axes, originalValue: result, seed: v)
+        })
+    }
+
+    // Note: adapted from `_MinOrMaxGrad`:
+    // https://github.com/tensorflow/tensorflow/blob/r2.1/tensorflow/python/ops/math_grad.py#L223.
+    @inlinable
+    func _vjpMinMaxHelper(
+        alongAxes axes: Tensor<Int32>,
+        originalValue: Tensor,
+        seed: Tensor
+    ) -> Tensor {
+        // Compute the number of selected (maximum or minimum) elements in each reduction dimension.
+        // If there are multiple minimum or maximum elements then the gradient will be divided
+        // between them.
+        let indicators = Tensor(originalValue .== self)
+        let selectedCount = indicators.sum(alongAxes: axes)
+        return seed.broadcasted(toShape: self.shapeTensor) * indicators / selectedCount
+    }
+
+    @inlinable
+    func _vjpMax(alongAxes axes: Tensor<Int32>) -> (Tensor, (Tensor) -> Tensor) {
+        let result = max(alongAxes: axes)
+        return (result, { v in
+            self._vjpMinMaxHelper(alongAxes: axes, originalValue: result, seed: v)
+        })
+    }
+
+    @inlinable
+    func _vjpMin(alongAxes axes: Tensor<Int32>) -> (Tensor, (Tensor) -> Tensor) {
+        let result = min(alongAxes: axes)
+        return (result, { v in
+            self._vjpMinMaxHelper(alongAxes: axes, originalValue: result, seed: v)
+        })
+    }
 }
 
 // MARK: - Numeric Reductions
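
For intuition on the shared `_vjpMinMaxHelper` logic: when several elements tie for the extremum along a reduced dimension, the incoming gradient for that slice is split evenly among them (`indicators / selectedCount`). A minimal sketch (assuming a Swift for TensorFlow toolchain; values taken from the tie case in the tests below):

import TensorFlow

let x: Tensor<Float> = [[0, 1, 2], [2, 1, 0]]
// Column 1 ties (1 vs. 1), so its unit gradient is split 0.5 / 0.5;
// columns 0 and 2 route the full gradient to their unique minimum.
let grad = gradient(at: x) { $0.min(alongAxes: 0).sum() }
// grad == [[1.0, 0.5, 0.0], [0.0, 0.5, 1.0]]
print(grad)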

Tests/TensorFlowTests/TensorAutoDiffTests.swift

Lines changed: 141 additions & 31 deletions
@@ -284,77 +284,187 @@ final class TensorAutoDiffTests: XCTestCase {
         XCTAssertEqual(varianceGradAlongAxes(input), expected)
     }
 
-    func testMin() {
-        // The expected gradient values were computed using the following TensorFlow 2.0 Beta1
-        // Python code with respective `a` and `b` tensors:
+    func testMax() {
+        // Expected gradient values were computed using the following TensorFlow Python code:
         // ```
+        // import tensorflow as tf
         // with tf.GradientTape() as t:
         //     t.watch([a, b])
-        //     y = tf.math.reduce_sum(tf.minimum(a, b))
+        //     y = tf.reduce_sum(tf.maximum(a, b))
         // print(t.gradient(y, [a, b]))
         // ```
         do {
             let a = Tensor<Float>([4, 5, 3])
             let b = Tensor<Float>([4, 2, 6])
-            let computedGradient1 = gradient(at: a, b) { a, b in min(a, b).sum() }
-            let expectedGradient1: (Tensor<Float>, Tensor<Float>) = (
-                [1.0, 0.0, 1.0], [0.0, 1.0, 0.0])
+            let computedGradient1 = gradient(at: a, b) { a, b in max(a, b).sum() }
+            let expectedGradient1: (Tensor<Float>, Tensor<Float>) = ([1, 1, 0], [0, 0, 1])
             XCTAssertEqual(computedGradient1.0, expectedGradient1.0)
             XCTAssertEqual(computedGradient1.1, expectedGradient1.1)
 
-            let computedGradient2 = gradient(at: a, b) { a, b in min(b, a).sum() }
-            let expectedGradient2: (Tensor<Float>, Tensor<Float>) = (
-                [0.0, 0.0, 1.0], [1.0, 1.0, 0.0])
+            let computedGradient2 = gradient(at: a, b) { a, b in max(b, a).sum() }
+            let expectedGradient2: (Tensor<Float>, Tensor<Float>) = ([0, 1, 0], [1, 0, 1])
             XCTAssertEqual(computedGradient2.0, expectedGradient2.0)
             XCTAssertEqual(computedGradient2.1, expectedGradient2.1)
         }
-
         do {
-            let a = Tensor<Float>([[3.0, -2.0], [0.3, 10.0]])
-            let b = Tensor<Float>([9.0, -3.0])
-            let computedGradient = gradient(at: a, b) { a, b in min(a, b).sum() }
-            let expectedGradient: (Tensor<Float>, Tensor<Float>) = (
-                [[1.0, 0.0], [1.0, 0.0]], [0.0, 2.0])
+            let a = Tensor<Float>([[3, -2], [0.3, 10]])
+            let b = Tensor<Float>([9, -3])
+            let computedGradient = gradient(at: a, b) { a, b in max(a, b).sum() }
+            let expectedGradient: (Tensor<Float>, Tensor<Float>) = ([[0, 1], [0, 1]], [2, 0])
             XCTAssertEqual(computedGradient.0, expectedGradient.0)
             XCTAssertEqual(computedGradient.1, expectedGradient.1)
         }
     }
 
-    func testMax() {
-        // The expected gradient values were computed using the following TensorFlow 2.0 Beta1
-        // Python code with respective `a` and `b` tensors:
+    func testMin() {
+        // Expected gradient values were computed using the following TensorFlow Python code:
         // ```
+        // import tensorflow as tf
         // with tf.GradientTape() as t:
         //     t.watch([a, b])
-        //     y = tf.math.reduce_sum(tf.maximum(a, b))
+        //     y = tf.reduce_sum(tf.minimum(a, b))
         // print(t.gradient(y, [a, b]))
         // ```
         do {
             let a = Tensor<Float>([4, 5, 3])
             let b = Tensor<Float>([4, 2, 6])
-            let computedGradient1 = gradient(at: a, b) { a, b in max(a, b).sum() }
-            let expectedGradient1: (Tensor<Float>, Tensor<Float>) = (
-                [1.0, 1.0, 0.0], [0.0, 0.0, 1.0])
+            let computedGradient1 = gradient(at: a, b) { a, b in min(a, b).sum() }
+            let expectedGradient1: (Tensor<Float>, Tensor<Float>) = ([1, 0, 1], [0, 1, 0])
             XCTAssertEqual(computedGradient1.0, expectedGradient1.0)
             XCTAssertEqual(computedGradient1.1, expectedGradient1.1)
 
-            let computedGradient2 = gradient(at: a, b) { a, b in max(b, a).sum() }
-            let expectedGradient2: (Tensor<Float>, Tensor<Float>) = (
-                [0.0, 1.0, 0.0], [1.0, 0.0, 1.0])
+            let computedGradient2 = gradient(at: a, b) { a, b in min(b, a).sum() }
+            let expectedGradient2: (Tensor<Float>, Tensor<Float>) = ([0, 0, 1], [1, 1, 0])
             XCTAssertEqual(computedGradient2.0, expectedGradient2.0)
             XCTAssertEqual(computedGradient2.1, expectedGradient2.1)
         }
+
         do {
-            let a = Tensor<Float>([[3.0, -2.0], [0.3, 10.0]])
-            let b = Tensor<Float>([9.0, -3.0])
-            let computedGradient = gradient(at: a, b) { a, b in max(a, b).sum() }
-            let expectedGradient: (Tensor<Float>, Tensor<Float>) = (
-                [[0.0, 1.0], [0.0, 1.0]], [2.0, 0.0])
+            let a = Tensor<Float>([[3, -2], [0.3, 10]])
+            let b = Tensor<Float>([9, -3])
+            let computedGradient = gradient(at: a, b) { a, b in min(a, b).sum() }
+            let expectedGradient: (Tensor<Float>, Tensor<Float>) = ([[1, 0], [1, 0]], [0, 2])
             XCTAssertEqual(computedGradient.0, expectedGradient.0)
             XCTAssertEqual(computedGradient.1, expectedGradient.1)
         }
     }
 
+    func testMaxAlongAxes() {
+        // Expected gradient values were computed using the following TensorFlow Python code:
+        // ```
+        // import tensorflow as tf
+        // x = tf.constant(range(6), shape=(2, 3), dtype=float)
+        // with tf.GradientTape() as t:
+        //     t.watch(x)
+        //     y = tf.reduce_sum(tf.reduce_max(x, axis=0, keepdims=True))
+        // print(t.gradient(y, x))
+        // ```
+        func maxAlongAxesSum(_ x: Tensor<Float>) -> Tensor<Float> {
+            x.max(alongAxes: 0).sum()
+        }
+        do {
+            let x: Tensor<Float> = [[0, 1, 2], [3, 4, 5]]
+            let (value, computedGradient) = valueWithGradient(at: x, in: maxAlongAxesSum)
+            XCTAssertEqual(value, maxAlongAxesSum(x))
+            let expectedGradient: Tensor<Float> = [[0, 0, 0], [1, 1, 1]]
+            XCTAssertEqual(computedGradient, expectedGradient)
+        }
+        do {
+            let x: Tensor<Float> = [[0, 1, 2], [2, 1, 0]]
+            let (value, computedGradient) = valueWithGradient(at: x, in: maxAlongAxesSum)
+            XCTAssertEqual(value, maxAlongAxesSum(x))
+            let expectedGradient: Tensor<Float> = [[0, 0.5, 1], [1, 0.5, 0]]
+            XCTAssertEqual(computedGradient, expectedGradient)
+        }
+    }
+
+    func testMinAlongAxes() {
+        // Expected gradient values were computed using the following TensorFlow Python code:
+        // ```
+        // import tensorflow as tf
+        // x = tf.constant(range(6), shape=(2, 3), dtype=float)
+        // with tf.GradientTape() as t:
+        //     t.watch(x)
+        //     y = tf.reduce_sum(tf.reduce_min(x, axis=0, keepdims=True))
+        // print(t.gradient(y, x))
+        // ```
+        func minAlongAxesSum(_ x: Tensor<Float>) -> Tensor<Float> {
+            x.min(alongAxes: 0).sum()
+        }
+        do {
+            let x: Tensor<Float> = [[0, 1, 2], [3, 4, 5]]
+            let (value, computedGradient) = valueWithGradient(at: x, in: minAlongAxesSum)
+            XCTAssertEqual(value, minAlongAxesSum(x))
+            let expectedGradient: Tensor<Float> = [[1, 1, 1], [0, 0, 0]]
+            XCTAssertEqual(computedGradient, expectedGradient)
+        }
+        do {
+            let x: Tensor<Float> = [[0, 1, 2], [2, 1, 0]]
+            let (value, computedGradient) = valueWithGradient(at: x, in: minAlongAxesSum)
+            XCTAssertEqual(value, minAlongAxesSum(x))
+            let expectedGradient: Tensor<Float> = [[1, 0.5, 0], [0, 0.5, 1]]
+            XCTAssertEqual(computedGradient, expectedGradient)
+        }
+    }
+
+    func testMaxSqueezingAxes() {
+        // Expected gradient values were computed using the following TensorFlow Python code:
+        // ```
+        // import tensorflow as tf
+        // x = tf.constant(range(6), shape=(2, 3), dtype=float)
+        // with tf.GradientTape() as t:
+        //     t.watch(x)
+        //     y = tf.reduce_sum(tf.reduce_max(x, axis=0, keepdims=False))
+        // print(t.gradient(y, x))
+        // ```
+        func maxSqueezingAxesSum(_ x: Tensor<Float>) -> Tensor<Float> {
+            x.max(squeezingAxes: 0).sum()
+        }
+        do {
+            let x: Tensor<Float> = [[0, 1, 2], [3, 4, 5]]
+            let (value, computedGradient) = valueWithGradient(at: x, in: maxSqueezingAxesSum)
+            XCTAssertEqual(value, maxSqueezingAxesSum(x))
+            let expectedGradient: Tensor<Float> = [[0, 0, 0], [1, 1, 1]]
+            XCTAssertEqual(computedGradient, expectedGradient)
+        }
+        do {
+            let x: Tensor<Float> = [[0, 1, 2], [2, 1, 0]]
+            let (value, computedGradient) = valueWithGradient(at: x, in: maxSqueezingAxesSum)
+            XCTAssertEqual(value, maxSqueezingAxesSum(x))
+            let expectedGradient: Tensor<Float> = [[0, 0.5, 1], [1, 0.5, 0]]
+            XCTAssertEqual(computedGradient, expectedGradient)
+        }
+    }
+
+    func testMinSqueezingAxes() {
+        // Expected gradient values were computed using the following TensorFlow Python code:
+        // ```
+        // import tensorflow as tf
+        // x = tf.constant(range(6), shape=(2, 3), dtype=float)
+        // with tf.GradientTape() as t:
+        //     t.watch(x)
+        //     y = tf.reduce_sum(tf.reduce_min(x, axis=0, keepdims=False))
+        // print(t.gradient(y, x))
+        // ```
+        func minSqueezingAxesSum(_ x: Tensor<Float>) -> Tensor<Float> {
+            x.min(squeezingAxes: 0).sum()
+        }
+        do {
+            let x: Tensor<Float> = [[0, 1, 2], [3, 4, 5]]
+            let (value, computedGradient) = valueWithGradient(at: x, in: minSqueezingAxesSum)
+            XCTAssertEqual(value, minSqueezingAxesSum(x))
+            let expectedGradient: Tensor<Float> = [[1, 1, 1], [0, 0, 0]]
+            XCTAssertEqual(computedGradient, expectedGradient)
+        }
+        do {
+            let x: Tensor<Float> = [[0, 1, 2], [2, 1, 0]]
+            let (value, computedGradient) = valueWithGradient(at: x, in: minSqueezingAxesSum)
+            XCTAssertEqual(value, minSqueezingAxesSum(x))
+            let expectedGradient: Tensor<Float> = [[1, 0.5, 0], [0, 0.5, 1]]
+            XCTAssertEqual(computedGradient, expectedGradient)
+        }
+    }
+
     func testTensorInitStacking() {
         let a1 = Tensor<Float>([1, 2, 3, 4, 5])
         let b1 = Tensor<Float>([6, 7, 8, 9, 10])
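
To run just these tests from a swift-apis checkout (a usage sketch, assuming a Swift for TensorFlow toolchain and SwiftPM; `--filter` matches test names by regex):

swift test --filter TensorAutoDiffTests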
