Skip to content

Commit cebf600

Browse files
committed
MAINT: count_nonzero
1 parent 5dc5b5a commit cebf600

File tree

3 files changed

+8
-6
lines changed

3 files changed

+8
-6
lines changed

torch_np/_detail/_reductions.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ def _atleast_float(dtype, other_dtype):
6565
return dtype
6666

6767

68+
@emulate_keepdims
69+
@deco_axis_expand
6870
def count_nonzero(a, axis=None):
6971
# XXX: this all should probably be generalized to a sum(a != 0, dtype=bool)
7072
try:

torch_np/_funcs.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -442,3 +442,9 @@ def any(a: ArrayLike, axis: AxisLike=None, out=None, keepdims=NoValue, *, where=
442442
result = _reductions.any(a, axis=axis, where=where, keepdims=keepdims)
443443
return _helpers.result_or_out(result, out)
444444

445+
446+
@normalizer
447+
def count_nonzero(a: ArrayLike, axis: AxisLike=None, *, keepdims=False):
448+
result = _reductions.count_nonzero(a, axis=axis, keepdims=keepdims)
449+
return _helpers.array_from(result)
450+

torch_np/_wrapper.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -582,12 +582,6 @@ def flatnonzero(a):
582582
return _funcs.nonzero(arr.ravel())[0]
583583

584584

585-
from ._decorators import emulate_out_arg
586-
from ._ndarray import axis_keepdims_wrapper
587-
588-
count_nonzero = emulate_out_arg(axis_keepdims_wrapper(_reductions.count_nonzero))
589-
590-
591585
@normalizer
592586
def roll(a: ArrayLike, shift, axis=None):
593587
result = _impl.roll(a, shift, axis)

0 commit comments

Comments
 (0)