@@ -1963,6 +1963,79 @@ public extension Tensor where Scalar: Numeric {
1963
1963
) -> Tensor {
1964
1964
Raw . cumsum ( self , axis: axis, exclusive: exclusive, reverse: reverse)
1965
1965
}
1966
+
1967
+ /// Returns the cumulative product of this tensor along the specified axis. By default, this
1968
+ /// function performs an inclusive cumulative product which means that the first element of the
1969
+ /// input is identical to the first element of the output:
1970
+ /// ```
1971
+ /// Tensor<Float>([a, b, c]).cumulativeProduct() = Tensor<Float>([a, a * b, a * b * c])
1972
+ /// ```
1973
+ /// By setting the `exclusive` argument to `true`, an exclusive cumulative product is performed
1974
+ /// instead:
1975
+ /// ```
1976
+ /// Tensor<Float>([a, b, c]).cumulativeProduct(exclusive: true) = Tensor<Float>([1, a, a * b])
1977
+ /// ```
1978
+ /// By setting the `reverse` argument to `true`, the cumulative product is performed in the
1979
+ /// opposite direction:
1980
+ /// ```
1981
+ /// Tensor<Float>([a, b, c]).cumulativeProduct(reverse: true) ==
1982
+ /// Tensor<Float>([a * b * c, a * b, a])
1983
+ /// ```
1984
+ /// This is more efficient than separately reversing the resulting tensor.
1985
+ ///
1986
+ /// - Parameters:
1987
+ /// - axis: Axis along which to perform the cumulative product operation.
1988
+ /// - exclusive: Indicates whether to perform an exclusive cumulative product.
1989
+ /// - reverse: Indicates whether to perform the cumulative product in reversed order.
1990
+ /// - Returns: Result of the cumulative product operation.
1991
+ /// - Precondition: `axis` must be in the range `-rank..<rank`.
1992
+ @inlinable
1993
+ @differentiable ( wrt: self where Scalar: TensorFlowFloatingPoint)
1994
+ func cumulativeProduct(
1995
+ alongAxis axis: Int ,
1996
+ exclusive: Bool = false ,
1997
+ reverse: Bool = false
1998
+ ) -> Tensor {
1999
+ cumulativeProduct (
2000
+ alongAxis: Tensor < Int32 > ( Int32 ( axis) ) ,
2001
+ exclusive: exclusive,
2002
+ reverse: reverse)
2003
+ }
2004
+
2005
+ /// Returns the cumulative product of this tensor along the specified axis. By default, this
2006
+ /// function performs an inclusive cumulative product which means that the first element of the
2007
+ /// input is identical to the first element of the output:
2008
+ /// ```
2009
+ /// Tensor<Float>([a, b, c]).cumulativeProduct() = Tensor<Float>([a, a * b, a * b * c])
2010
+ /// ```
2011
+ /// By setting the `exclusive` argument to `true`, an exclusive cumulative product is performed
2012
+ /// instead:
2013
+ /// ```
2014
+ /// Tensor<Float>([a, b, c]).cumulativeProduct(exclusive: true) = Tensor<Float>([1, a, a * b])
2015
+ /// ```
2016
+ /// By setting the `reverse` argument to `true`, the cumulative product is performed in the
2017
+ /// opposite direction:
2018
+ /// ```
2019
+ /// Tensor<Float>([a, b, c]).cumulativeProduct(reverse: true) ==
2020
+ /// Tensor<Float>([a * b * c, a * b, a])
2021
+ /// ```
2022
+ /// This is more efficient than separately reversing the resulting tensor.
2023
+ ///
2024
+ /// - Parameters:
2025
+ /// - axis: Axis along which to perform the cumulative product operation.
2026
+ /// - exclusive: Indicates whether to perform an exclusive cumulative product.
2027
+ /// - reverse: Indicates whether to perform the cumulative product in reversed order.
2028
+ /// - Returns: Result of the cumulative product operation.
2029
+ /// - Precondition: `axis` must be in the range `-rank..<rank`.
2030
+ @inlinable
2031
+ @differentiable ( wrt: self , vjp: _vjpCumulativeProduct where Scalar: TensorFlowFloatingPoint)
2032
+ func cumulativeProduct(
2033
+ alongAxis axis: Tensor < Int32 > ,
2034
+ exclusive: Bool = false ,
2035
+ reverse: Bool = false
2036
+ ) -> Tensor {
2037
+ Raw . cumprod ( self , axis: axis, exclusive: exclusive, reverse: reverse)
2038
+ }
1966
2039
}
1967
2040
1968
2041
internal extension Tensor where Scalar: TensorFlowFloatingPoint {
@@ -2008,6 +2081,22 @@ internal extension Tensor where Scalar: TensorFlowFloatingPoint {
2008
2081
v. cumulativeSum ( alongAxis: axis, exclusive: exclusive, reverse: !reverse)
2009
2082
} )
2010
2083
}
2084
+
2085
+ @inlinable
2086
+ func _vjpCumulativeProduct(
2087
+ alongAxis axis: Tensor < Int32 > ,
2088
+ exclusive: Bool = false ,
2089
+ reverse: Bool = false
2090
+ ) -> ( Tensor , ( Tensor ) -> Tensor ) {
2091
+ let result = cumulativeProduct ( alongAxis: axis, exclusive: exclusive, reverse: reverse)
2092
+ return ( result, { v in
2093
+ ( result * v) . cumulativeSum (
2094
+ alongAxis: axis,
2095
+ exclusive: exclusive,
2096
+ reverse: !reverse
2097
+ ) / self
2098
+ } )
2099
+ }
2011
2100
}
2012
2101
2013
2102
// TODO: Consider making the return type be generic over `FloatingPoint` types
0 commit comments