Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 1f5e7fa

Browse files
authored
Added support for 'Tensor.cumulativeProduct(alongAxis:exclusive:reverse:)'. (#358)
1 parent 3b9af2f commit 1f5e7fa

File tree

2 files changed

+117
-0
lines changed

2 files changed

+117
-0
lines changed

Sources/TensorFlow/Operators/Math.swift

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1963,6 +1963,79 @@ public extension Tensor where Scalar: Numeric {
19631963
) -> Tensor {
19641964
Raw.cumsum(self, axis: axis, exclusive: exclusive, reverse: reverse)
19651965
}
1966+
1967+
/// Returns the cumulative product of this tensor along the specified axis. By default, this
1968+
/// function performs an inclusive cumulative product which means that the first element of the
1969+
/// input is identical to the first element of the output:
1970+
/// ```
1971+
/// Tensor<Float>([a, b, c]).cumulativeProduct() = Tensor<Float>([a, a * b, a * b * c])
1972+
/// ```
1973+
/// By setting the `exclusive` argument to `true`, an exclusive cumulative product is performed
1974+
/// instead:
1975+
/// ```
1976+
/// Tensor<Float>([a, b, c]).cumulativeProduct(exclusive: true) = Tensor<Float>([1, a, a * b])
1977+
/// ```
1978+
/// By setting the `reverse` argument to `true`, the cumulative product is performed in the
1979+
/// opposite direction:
1980+
/// ```
1981+
/// Tensor<Float>([a, b, c]).cumulativeProduct(reverse: true) ==
1982+
/// Tensor<Float>([a * b * c, a * b, a])
1983+
/// ```
1984+
/// This is more efficient than separately reversing the resulting tensor.
1985+
///
1986+
/// - Parameters:
1987+
/// - axis: Axis along which to perform the cumulative product operation.
1988+
/// - exclusive: Indicates whether to perform an exclusive cumulative product.
1989+
/// - reverse: Indicates whether to perform the cumulative product in reversed order.
1990+
/// - Returns: Result of the cumulative product operation.
1991+
/// - Precondition: `axis` must be in the range `-rank..<rank`.
1992+
@inlinable
1993+
@differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
1994+
func cumulativeProduct(
1995+
alongAxis axis: Int,
1996+
exclusive: Bool = false,
1997+
reverse: Bool = false
1998+
) -> Tensor {
1999+
cumulativeProduct(
2000+
alongAxis: Tensor<Int32>(Int32(axis)),
2001+
exclusive: exclusive,
2002+
reverse: reverse)
2003+
}
2004+
2005+
/// Returns the cumulative product of this tensor along the specified axis. By default, this
2006+
/// function performs an inclusive cumulative product which means that the first element of the
2007+
/// input is identical to the first element of the output:
2008+
/// ```
2009+
/// Tensor<Float>([a, b, c]).cumulativeProduct() = Tensor<Float>([a, a * b, a * b * c])
2010+
/// ```
2011+
/// By setting the `exclusive` argument to `true`, an exclusive cumulative product is performed
2012+
/// instead:
2013+
/// ```
2014+
/// Tensor<Float>([a, b, c]).cumulativeProduct(exclusive: true) = Tensor<Float>([1, a, a * b])
2015+
/// ```
2016+
/// By setting the `reverse` argument to `true`, the cumulative product is performed in the
2017+
/// opposite direction:
2018+
/// ```
2019+
/// Tensor<Float>([a, b, c]).cumulativeProduct(reverse: true) ==
2020+
/// Tensor<Float>([a * b * c, a * b, a])
2021+
/// ```
2022+
/// This is more efficient than separately reversing the resulting tensor.
2023+
///
2024+
/// - Parameters:
2025+
/// - axis: Axis along which to perform the cumulative product operation.
2026+
/// - exclusive: Indicates whether to perform an exclusive cumulative product.
2027+
/// - reverse: Indicates whether to perform the cumulative product in reversed order.
2028+
/// - Returns: Result of the cumulative product operation.
2029+
/// - Precondition: `axis` must be in the range `-rank..<rank`.
2030+
@inlinable
2031+
@differentiable(wrt: self, vjp: _vjpCumulativeProduct where Scalar: TensorFlowFloatingPoint)
2032+
func cumulativeProduct(
2033+
alongAxis axis: Tensor<Int32>,
2034+
exclusive: Bool = false,
2035+
reverse: Bool = false
2036+
) -> Tensor {
2037+
Raw.cumprod(self, axis: axis, exclusive: exclusive, reverse: reverse)
2038+
}
19662039
}
19672040

19682041
internal extension Tensor where Scalar: TensorFlowFloatingPoint {
@@ -2008,6 +2081,22 @@ internal extension Tensor where Scalar: TensorFlowFloatingPoint {
20082081
v.cumulativeSum(alongAxis: axis, exclusive: exclusive, reverse: !reverse)
20092082
})
20102083
}
2084+
2085+
@inlinable
2086+
func _vjpCumulativeProduct(
2087+
alongAxis axis: Tensor<Int32>,
2088+
exclusive: Bool = false,
2089+
reverse: Bool = false
2090+
) -> (Tensor, (Tensor) -> Tensor) {
2091+
let result = cumulativeProduct(alongAxis: axis, exclusive: exclusive, reverse: reverse)
2092+
return (result, { v in
2093+
(result * v).cumulativeSum(
2094+
alongAxis: axis,
2095+
exclusive: exclusive,
2096+
reverse: !reverse
2097+
) / self
2098+
})
2099+
}
20112100
}
20122101

20132102
// TODO: Consider making the return type be generic over `FloatingPoint` types

Tests/TensorFlowTests/OperatorTests/MathTests.swift

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,33 @@ final class MathOperatorTests: XCTestCase {
242242
XCTAssertEqual(reverseExclusiveCumsum1, Tensor<Float>([[3, 2, 0], [9, 5, 0]]))
243243
}
244244

245+
func testCumulativeProduct() {
246+
// 2 x 3
247+
let x = Tensor<Float>([[0, 1, 2], [3, 4, 5]])
248+
let cumprod0 = x.cumulativeProduct(alongAxis: 0)
249+
let cumprod1 = x.cumulativeProduct(alongAxis: 1)
250+
let exclusiveCumprod0 = x.cumulativeProduct(alongAxis: 0, exclusive: true)
251+
let exclusiveCumprod1 = x.cumulativeProduct(alongAxis: 1, exclusive: true)
252+
let reverseCumprod0 = x.cumulativeProduct(alongAxis: 0, reverse: true)
253+
let reverseCumprod1 = x.cumulativeProduct(alongAxis: 1, reverse: true)
254+
let reverseExclusiveCumprod0 = x.cumulativeProduct(
255+
alongAxis: 0,
256+
exclusive: true,
257+
reverse: true)
258+
let reverseExclusiveCumprod1 = x.cumulativeProduct(
259+
alongAxis: 1,
260+
exclusive: true,
261+
reverse: true)
262+
XCTAssertEqual(cumprod0, Tensor<Float>([[0, 1, 2], [0, 4, 10]]))
263+
XCTAssertEqual(cumprod1, Tensor<Float>([[0, 0, 0], [3, 12, 60]]))
264+
XCTAssertEqual(exclusiveCumprod0, Tensor<Float>([[1, 1, 1], [0, 1, 2]]))
265+
XCTAssertEqual(exclusiveCumprod1, Tensor<Float>([[1, 0, 0], [1, 3, 12]]))
266+
XCTAssertEqual(reverseCumprod0, Tensor<Float>([[0, 4, 10], [3, 4, 5]]))
267+
XCTAssertEqual(reverseCumprod1, Tensor<Float>([[0, 2, 2], [60, 20, 5]]))
268+
XCTAssertEqual(reverseExclusiveCumprod0, Tensor<Float>([[3, 4, 5], [1, 1, 1]]))
269+
XCTAssertEqual(reverseExclusiveCumprod1, Tensor<Float>([[2, 2, 1], [20, 5, 1]]))
270+
}
271+
245272
func testStandardDeviation() {
246273
XCTAssertEqual(Tensor<Float>([1]).standardDeviation(), Tensor(0))
247274
XCTAssertEqual(Tensor<Float>([0, 1]).standardDeviation(alongAxes: 0), Tensor(0.5))
@@ -485,6 +512,7 @@ final class MathOperatorTests: XCTestCase {
485512
("testArgmax", testArgmax),
486513
("testReduction", testReduction),
487514
("testCumulativeSum", testCumulativeSum),
515+
("testCumulativeProduct", testCumulativeProduct),
488516
("testStandardDeviation", testStandardDeviation),
489517
("testLogSumExp", testLogSumExp),
490518
("testMoments", testMoments),

0 commit comments

Comments
 (0)