
Commit 4f193c1

kongziidan-zheng authored and committed

Make Tensor.product(squeezingAxes:) differentiable (#550)

1 parent 35dfddf · commit 4f193c1

File tree (2 files changed, +92 -2 lines):

Sources/TensorFlow/Operators/Math.swift
Tests/TensorFlowTests/TensorAutoDiffTests.swift
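For orientation before the diffs: the commit registers a custom vector-Jacobian product (VJP) for `product`, so reverse-mode differentiation now flows through the reduction. A minimal sketch of the kind of call this enables, mirroring the `gradient(at:in:)` assertions added in the tests below (assuming a Swift for TensorFlow toolchain from this era):

```swift
import TensorFlow

// With this commit, `product` participates in differentiation. For
// x = [[10], [20]] the product is 10 * 20, and each entry's partial
// derivative is the product of the remaining entries, so the gradient
// is [[20], [10]], exactly what `testProductGrad` asserts.
let x = Tensor<Float>([[10], [20]])
let g = gradient(at: x) { (x: Tensor<Float>) in x.product().sum() }
print(g)  // [[20.0], [10.0]]
```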

Sources/TensorFlow/Operators/Math.swift

Lines changed: 52 additions & 1 deletion
@@ -1812,8 +1812,8 @@ public extension Tensor where Scalar: Numeric {
     ///
     /// - Parameter axes: The dimensions to reduce.
     /// - Precondition: Each value in `axes` must be in the range `-rank...rank`.
-    // TODO: Make this @differentiable.
     @inlinable
+    @differentiable(wrt: self, vjp: _vjpProduct(squeezingAxes:) where Scalar: TensorFlowFloatingPoint)
     func product(squeezingAxes axes: Tensor<Int32>) -> Tensor {
         _Raw.prod(self, reductionIndices: axes, keepDims: false)
     }
@@ -1823,6 +1823,7 @@ public extension Tensor where Scalar: Numeric {
     /// - Parameter axes: The dimensions to reduce.
     /// - Precondition: Each value in `axes` must be in the range `-rank...rank`.
     @inlinable
+    @differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
     func product(squeezingAxes axes: [Int]) -> Tensor {
         // TODO(TF-433): Remove workaround for differentiating `map`.
         let axes = {axes.map(Int32.init)}()
@@ -1834,11 +1835,13 @@ public extension Tensor where Scalar: Numeric {
     /// - Parameter axes: The dimensions to reduce.
     /// - Precondition: Each value in `axes` must be in the range `-rank...rank`.
     @inlinable
+    @differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
     func product(squeezingAxes axes: Int...) -> Tensor {
         product(squeezingAxes: axes)
     }
 
     @inlinable
+    @differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
     func product() -> Tensor {
         flattened().product(squeezingAxes: 0)
     }
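The `vjp:` argument in the attributes above names a companion function that returns the original result together with a pullback closure; this is the pre-`@derivative(of:)` spelling used in swift-apis at the time. A hedged toy sketch of the same registration pattern, where `cubed` and `_vjpCubed` are illustrative names rather than part of this commit:

```swift
public extension Tensor where Scalar: Numeric {
    @inlinable
    @differentiable(wrt: self, vjp: _vjpCubed() where Scalar: TensorFlowFloatingPoint)
    func cubed() -> Tensor { self * self * self }
}

internal extension Tensor where Scalar: TensorFlowFloatingPoint {
    // The VJP returns the forward result plus a closure mapping an
    // incoming cotangent `v` to the cotangent of `self`: d(x^3)/dx = 3x^2.
    @inlinable
    func _vjpCubed() -> (Tensor, (Tensor) -> Tensor) {
        (cubed(), { v in v * 3 * self * self })
    }
}
```

The `_vjpProduct(squeezingAxes:)` added in the hunk below follows exactly this shape, with the extra work needed to handle arbitrary reduction axes and zero entries.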
@@ -2224,6 +2227,54 @@ internal extension Tensor where Scalar: TensorFlowFloatingPoint {
             ) / self
         })
     }
+
+    // Adapted from `_ProdGrad` in Python TensorFlow:
+    // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/math_grad.py
+    @inlinable
+    func _vjpProduct(squeezingAxes axes: Tensor<Int32>) -> (Tensor, (Tensor) -> Tensor) {
+        // The gradient can be expressed by dividing the product by each entry of the
+        // input tensor, but this approach can't deal with zeros in the input.
+        // Here, we avoid this problem by composing the output as a product of two
+        // `cumulativeProduct` operations.
+        let result = product(squeezingAxes: axes)
+        return (result, { v in
+            // Reshape reduction indices for the case where the parameter is a scalar.
+            var reductionIndices = axes.reshaped(to: TensorShape(-1))
+            // Normalize any negative reduction indices to positive values.
+            reductionIndices = (reductionIndices + Int32(self.rank)) % Int32(self.rank)
+
+            // Expand `v` to full input shape.
+            var outputShape = self.shape
+            for axis in reductionIndices.scalars {
+                outputShape[Int(axis)] = 1
+            }
+            let vReshaped = v.reshaped(to: outputShape)
+            let vBroadcasted = vReshaped.broadcasted(to: self.shape)
+
+            // Pack all reduced dimensions into a single one, so we can perform the
+            // `cumulativeProduct` operations.
+            let idx = Tensor<Int32>(0..<Int32(self.rank))
+            let other = Tensor<Int32>(
+                Array(Set(idx.scalars).symmetricDifference(reductionIndices.scalars)))
+            let perm = reductionIndices.concatenated(with: other)
+            let reducedNum = Int(
+                self.shapeTensor.gathering(atIndices: reductionIndices).product().scalarized())
+            let otherNum = Int(
+                self.shapeTensor.gathering(atIndices: other).product().scalarized())
+
+            let permuted = self.transposed(permutation: perm)
+            let reshaped = permuted.reshaped(to: [reducedNum, otherNum])
+            // Calculate product, leaving out the current entry.
+            let left = reshaped.cumulativeProduct(alongAxis: 0, exclusive: true, reverse: false)
+            let right = reshaped.cumulativeProduct(alongAxis: 0, exclusive: true, reverse: true)
+            let y = (left * right).reshaped(to: permuted.shape)
+
+            // Invert the transpose and reshape operations.
+            // Make sure to set the statically known shape information through a reshape.
+            return (vBroadcasted * y.transposed(permutation: _Raw.invertPermutation(perm)))
+                .reshaped(to: self.shape)
+        })
+    }
 }
 
 // TODO: Consider making the return type be generic over `FloatingPoint` types
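The comment block at the top of `_vjpProduct` is the heart of the change: dividing the product by each entry breaks down when the input contains zeros, so the gradient is instead assembled from two exclusive cumulative products. A standalone sketch of that identity on a plain array (illustrative code, not part of the commit):

```swift
// For input [a, b, c]:
//   left  (exclusive cumprod, forward) = [1,  a, ab]
//   right (exclusive cumprod, reverse) = [bc, c,  1]
//   left * right = [bc, ac, ab], the gradient of a * b * c.
// No division is involved, so zeros in the input are handled correctly.
func productGradient(_ xs: [Float]) -> [Float] {
    guard !xs.isEmpty else { return [] }
    var left = [Float](repeating: 1, count: xs.count)
    var right = [Float](repeating: 1, count: xs.count)
    for i in 1..<xs.count {
        left[i] = left[i - 1] * xs[i - 1]
    }
    for i in stride(from: xs.count - 2, through: 0, by: -1) {
        right[i] = right[i + 1] * xs[i + 1]
    }
    return zip(left, right).map(*)
}

// Matches the zero-containing column [3, 0, 5] that `testProductGrad`
// reduces along axis 1: only the zero entry gets a nonzero gradient.
print(productGradient([3, 0, 5]))  // [0.0, 15.0, 0.0]
```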

Tests/TensorFlowTests/TensorAutoDiffTests.swift

Lines changed: 40 additions & 1 deletion
@@ -525,6 +525,44 @@ final class TensorAutoDiffTests: XCTestCase {
         assertEqual(computedGradient, expectedGradient, accuracy: 0.0001)
     }
 
+    func testProductGrad() {
+        // The expected gradient values were computed using the following Python code:
+        // ```
+        // import tensorflow as tf
+        // # Adjust values of `x` and `axis` for each test.
+        // x = tf.constant([[[3, 4], [5, 6], [7, 8]], [[3, 5], [0, 6], [5, 6]]], dtype=tf.float32)
+        // axis = 1
+        // with tf.GradientTape() as t:
+        //     t.watch(x)
+        //     y = tf.reduce_prod(x, axis=axis)
+        //     z = tf.reduce_sum(y)
+        // print(t.gradient(z, x))
+        // ```
+        func product(_ x: Tensor<Float>) -> Tensor<Float> {
+            return x.product().sum()
+        }
+        func productSqueezingAxes1(_ x: Tensor<Float>) -> Tensor<Float> {
+            return x.product(squeezingAxes: 1).sum()
+        }
+        func productSqueezingAxes_Neg1(_ x: Tensor<Float>) -> Tensor<Float> {
+            return x.product(squeezingAxes: -1).sum()
+        }
+        func productSqueezingAxes01(_ x: Tensor<Float>) -> Tensor<Float> {
+            return x.product(squeezingAxes: [0, 1]).sum()
+        }
+        XCTAssertEqual(gradient(at: [[10], [20]], in: product), [[20], [10]])
+        XCTAssertEqual(gradient(at: [[10, 20], [20, 30]], in: productSqueezingAxes1),
+                       [[20, 10], [30, 20]])
+        XCTAssertEqual(gradient(at: [[10, 20], [20, 30]], in: productSqueezingAxes_Neg1),
+                       [[20, 10], [30, 20]])
+        XCTAssertEqual(gradient(at: [[[3, 4], [5, 6], [7, 8]], [[3, 5], [0, 6], [5, 6]]],
+                                in: productSqueezingAxes1),
+                       [[[35, 48], [21, 32], [15, 24]], [[0, 36], [15, 30], [0, 30]]])
+        XCTAssertEqual(gradient(at: [[[3, 4], [5, 6], [7, 8]], [[3, 5], [0, 6], [5, 6]]],
+                                in: productSqueezingAxes01),
+                       [[[0, 8640], [0, 5760], [0, 4320]], [[0, 6912], [1575, 5760], [0, 5760]]])
+    }
+
     static var allTests = [
         ("testSimpleGrad", testSimpleGrad),
         ("testGenericGrad", testGenericGrad),
@@ -569,6 +607,7 @@ final class TensorAutoDiffTests: XCTestCase {
         ("testUnbroadcastToShape", testUnbroadcastToShape),
         ("testUnbroadcastTo", testUnbroadcastTo),
         ("testUnbroadcastLike", testUnbroadcastLike),
-        ("testBatchNormalized", testBatchNormalized)
+        ("testBatchNormalized", testBatchNormalized),
+        ("testProductGrad", testProductGrad),
     ]
 }
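As a sanity check on the last assertion: reducing over axes `[0, 1]` leaves shape `[2]`, so each entry's gradient is the product of the other five entries in its channel. The two headline values can be reproduced by hand (illustrative snippet, not part of the commit):

```swift
let channel1: [Float] = [3, 5, 7, 3, 0, 5]  // first channel, contains one zero
let channel2: [Float] = [4, 6, 8, 5, 6, 6]  // second channel, no zeros
// Gradient w.r.t. the first entry of channel 2 is the product of the
// remaining entries: 6 * 8 * 5 * 6 * 6.
print(channel2.dropFirst().reduce(1, *))         // 8640.0
// With a zero present, only the zero entry has a nonzero gradient:
// the product of all the nonzero entries, 3 * 5 * 7 * 3 * 5.
print(channel1.filter { $0 != 0 }.reduce(1, *))  // 1575.0
```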
