Skip to content

Commit 0b36f2d

Browse files
committed
MAINT: cumsum/cumprod
1 parent cebf600 commit 0b36f2d

File tree

4 files changed

+32
-14
lines changed

4 files changed

+32
-14
lines changed

torch_np/_detail/_reductions.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,22 @@ def wrapped(tensor, axis=None, keepdims=NoValue, *args, **kwds):
4848
return result
4949
return wrapped
5050

def deco_axis_ravel(func):
    """Generically handle 'axis=None ravels' behavior.

    ``axis=None`` means: operate on the raveled tensor with axis 0;
    any other axis is validated/normalized against the tensor's rank
    before being passed through.
    """

    @functools.wraps(func)
    def wrapped(tensor, axis, *args, **kwds):
        if axis is not None:
            axis = _util.normalize_axis_index(axis, tensor.ndim)

        # Ravels when axis is None and substitutes axis=0.  XXX: inline
        (tensor,), axis = _util.axis_none_ravel(tensor, axis=axis)

        return func(tensor, axis=axis, *args, **kwds)

    return wrapped
5167
##################################
5268

5369

@@ -252,6 +268,7 @@ def var(tensor, axis=None, dtype=None, ddof=0, *, where=NoValue):
252268
# 2. axis=None ravels (cf concatenate)
253269

254270

271+
@deco_axis_ravel
255272
def cumprod(tensor, axis, dtype=None):
256273
if dtype == torch.bool:
257274
dtype = _dtypes_impl.default_int_dtype
@@ -263,6 +280,7 @@ def cumprod(tensor, axis, dtype=None):
263280
return result
264281

265282

283+
@deco_axis_ravel
266284
def cumsum(tensor, axis, dtype=None):
267285
if dtype == torch.bool:
268286
dtype = _dtypes_impl.default_int_dtype

torch_np/_detail/_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def apply_keepdims(tensor, axis, ndim):
128128

129129
def axis_none_ravel(*tensors, axis=None):
130130
"""Ravel the arrays if axis is none."""
131-
# XXX: is only used at `concatenate`. Inline unless reused more widely
131+
# XXX: is only used at `concatenate` and cumsum/cumprod. Inline unless reused more widely
132132
if axis is None:
133133
tensors = tuple(ar.ravel() for ar in tensors)
134134
return tensors, 0

torch_np/_funcs.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -448,3 +448,14 @@ def count_nonzero(a: ArrayLike, axis: AxisLike=None, *, keepdims=False):
448448
result = _reductions.count_nonzero(a, axis=axis, keepdims=keepdims)
449449
return _helpers.array_from(result)
450450

@normalizer
def cumsum(a: ArrayLike, axis: AxisLike = None, dtype: DTypeLike = None, out=None):
    """Return the cumulative sum of the elements along the given axis.

    Delegates to ``_reductions.cumsum``; ``out`` is honored via
    ``_helpers.result_or_out``.
    """
    return _helpers.result_or_out(_reductions.cumsum(a, axis=axis, dtype=dtype), out)
@normalizer
def cumprod(a: ArrayLike, axis: AxisLike = None, dtype: DTypeLike = None, out=None):
    """Return the cumulative product of the elements along the given axis.

    Delegates to ``_reductions.cumprod``; ``out`` is honored via
    ``_helpers.result_or_out``.
    """
    return _helpers.result_or_out(_reductions.cumprod(a, axis=axis, dtype=dtype), out)

torch_np/_ndarray.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,6 @@
33
import torch
44

55
from . import _binary_ufuncs, _dtypes, _funcs, _helpers, _unary_ufuncs
6-
from ._decorators import (
7-
NoValue,
8-
axis_keepdims_wrapper,
9-
axis_none_ravel_wrapper,
10-
dtype_to_torch,
11-
emulate_out_arg,
12-
)
136
from ._detail import _dtypes_impl, _flips, _reductions, _util
147
from ._detail import implementations as _impl
158

@@ -389,12 +382,8 @@ def sort(self, axis=-1, kind=None, order=None):
389382
var = _funcs.var
390383
std = _funcs.std
391384

392-
cumprod = emulate_out_arg(
393-
axis_none_ravel_wrapper(dtype_to_torch(_reductions.cumprod))
394-
)
395-
cumsum = emulate_out_arg(
396-
axis_none_ravel_wrapper(dtype_to_torch(_reductions.cumsum))
397-
)
385+
cumsum = _funcs.cumsum
386+
cumprod = _funcs.cumprod
398387

399388
### indexing ###
400389
@staticmethod

0 commit comments

Comments
 (0)