Skip to content

Commit 0f633af

Browse files
committed
Leverage on dpctl.tensor implementation of count_nonzero
1 parent 0c3dfe5 commit 0f633af

File tree

2 files changed

+32
-15
lines changed

2 files changed

+32
-15
lines changed

dpnp/dpnp_iface.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -654,7 +654,7 @@ def get_result_array(a, out=None, casting="safe"):
654654
655655
Parameters
656656
----------
657-
a : {dpnp_array}
657+
a : {dpnp.ndarray, usm_ndarray}
658658
Input array.
659659
out : {dpnp.ndarray, usm_ndarray}
660660
If provided, value of `a` array will be copied into it
@@ -671,6 +671,8 @@ def get_result_array(a, out=None, casting="safe"):
671671
"""
672672

673673
if out is None:
674+
if isinstance(a, dpt.usm_ndarray):
675+
return dpnp_array._create_from_usm_ndarray(a)
674676
return a
675677

676678
if isinstance(out, dpt.usm_ndarray):

dpnp/dpnp_iface_counting.py

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -44,25 +44,38 @@
4444
__all__ = ["count_nonzero"]
4545

4646

47-
def count_nonzero(a, axis=None, *, keepdims=False):
47+
def count_nonzero(a, axis=None, *, keepdims=False, out=None):
4848
"""
4949
Counts the number of non-zero values in the array `a`.
5050
5151
For full documentation refer to :obj:`numpy.count_nonzero`.
5252
53+
Parameters
54+
----------
55+
a : {dpnp.ndarray, usm_ndarray}
56+
The array for which to count non-zeros.
57+
axis : {None, int, tuple}, optional
58+
Axis or tuple of axes along which to count non-zeros.
59+
Default value means that non-zeros will be counted along a flattened
60+
version of `a`.
61+
Default: ``None``.
62+
keepdims : bool, optional
63+
If this is set to ``True``, the axes that are counted are left in the
64+
result as dimensions with size one. With this option, the result will
65+
broadcast correctly against the input array.
66+
Default: ``False``.
67+
out : {None, dpnp.ndarray, usm_ndarray}, optional
68+
The array into which the result is written. The data type of `out` must
69+
match the expected shape and the expected data type of the result.
70+
If ``None`` then a new array is returned.
71+
Default: ``None``.
72+
5373
Returns
5474
-------
5575
out : dpnp.ndarray
5676
Number of non-zero values in the array along a given axis.
57-
Otherwise, a zero-dimensional array with the total number of
58-
non-zero values in the array is returned.
59-
60-
Limitations
61-
-----------
62-
Parameters `a` is supported as either :class:`dpnp.ndarray`
63-
or :class:`dpctl.tensor.usm_ndarray`.
64-
Otherwise ``TypeError`` exception will be raised.
65-
Input array data types are limited by supported DPNP :ref:`Data types`.
77+
Otherwise, a zero-dimensional array with the total number of non-zero
78+
values in the array is returned.
6679
6780
See Also
6881
--------
@@ -87,8 +100,10 @@ def count_nonzero(a, axis=None, *, keepdims=False):
87100
88101
"""
89102

90-
# TODO: might be improved by implementing an extension
91-
# with `count_nonzero` kernel
92103
usm_a = dpnp.get_usm_ndarray(a)
93-
usm_a = dpt.astype(usm_a, dpnp.bool, copy=False)
94-
return dpnp.sum(usm_a, axis=axis, dtype=dpnp.intp, keepdims=keepdims)
104+
usm_out = None if out is None else dpnp.get_usm_ndarray(out)
105+
106+
usm_res = dpt.count_nonzero(
107+
usm_a, axis=axis, keepdims=keepdims, out=usm_out
108+
)
109+
return dpnp.get_result_array(usm_res, out)

0 commit comments

Comments
 (0)