Skip to content

Commit e7f5dad

Browse files
committed
ENH: add cumulative_prod
1 parent 2eafb97 commit e7f5dad

File tree

5 files changed

+35
-1
lines changed

5 files changed

+35
-1
lines changed

array_api_compat/common/_aliases.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,36 @@ def cumulative_sum(
292292
)
293293
return res
294294

295+
296+
def cumulative_prod(
297+
x: ndarray,
298+
/,
299+
xp,
300+
*,
301+
axis: Optional[int] = None,
302+
dtype: Optional[Dtype] = None,
303+
include_initial: bool = False,
304+
**kwargs
305+
) -> ndarray:
306+
wrapped_xp = array_namespace(x)
307+
308+
if axis is None:
309+
if x.ndim > 1:
310+
raise ValueError("axis must be specified in cumulative_prod for more than one dimension")
311+
axis = 0
312+
313+
res = xp.cumprod(x, axis=axis, dtype=dtype, **kwargs)
314+
315+
# np.cumprod does not support include_initial
316+
if include_initial:
317+
initial_shape = list(x.shape)
318+
initial_shape[axis] = 1
319+
res = xp.concatenate(
320+
[wrapped_xp.ones(shape=initial_shape, dtype=res.dtype, device=device(res)), res],
321+
axis=axis,
322+
)
323+
return res
324+
295325
# The min and max argument names in clip are different and not optional in numpy, and type
296326
# promotion behavior is different.
297327
def clip(
@@ -544,7 +574,7 @@ def sign(x: ndarray, /, xp, **kwargs) -> ndarray:
544574
'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like',
545575
'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
546576
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
547-
'std', 'var', 'cumulative_sum', 'clip', 'permute_dims',
577+
'astype', 'std', 'var', 'cumulative_sum', 'cumulative_prod','clip', 'permute_dims',
548578
'reshape', 'argsort', 'sort', 'nonzero', 'ceil', 'floor', 'trunc',
549579
'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype',
550580
'unstack', 'sign']

array_api_compat/cupy/_aliases.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
std = get_xp(cp)(_aliases.std)
5050
var = get_xp(cp)(_aliases.var)
5151
cumulative_sum = get_xp(cp)(_aliases.cumulative_sum)
52+
cumulative_prod = get_xp(cp)(_aliases.cumulative_prod)
5253
clip = get_xp(cp)(_aliases.clip)
5354
permute_dims = get_xp(cp)(_aliases.permute_dims)
5455
reshape = get_xp(cp)(_aliases.reshape)

array_api_compat/dask/array/_aliases.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ def arange(
121121
std = get_xp(da)(_aliases.std)
122122
var = get_xp(da)(_aliases.var)
123123
cumulative_sum = get_xp(da)(_aliases.cumulative_sum)
124+
cumulative_prod = get_xp(da)(_aliases.cumulative_prod)
124125
empty = get_xp(da)(_aliases.empty)
125126
empty_like = get_xp(da)(_aliases.empty_like)
126127
full = get_xp(da)(_aliases.full)

array_api_compat/numpy/_aliases.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
std = get_xp(np)(_aliases.std)
5050
var = get_xp(np)(_aliases.var)
5151
cumulative_sum = get_xp(np)(_aliases.cumulative_sum)
52+
cumulative_prod = get_xp(np)(_aliases.cumulative_prod)
5253
clip = get_xp(np)(_aliases.clip)
5354
permute_dims = get_xp(np)(_aliases.permute_dims)
5455
reshape = get_xp(np)(_aliases.reshape)

array_api_compat/torch/_aliases.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,7 @@ def min(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keep
210210
clip = get_xp(torch)(_aliases_clip)
211211
unstack = get_xp(torch)(_aliases_unstack)
212212
cumulative_sum = get_xp(torch)(_aliases_cumulative_sum)
213+
cumulative_prod = get_xp(torch)(_aliases_cumulative_prod)
213214

214215
# torch.sort also returns a tuple
215216
# https://github.com/pytorch/pytorch/issues/70921

0 commit comments

Comments
 (0)