44
44
__all__ = ["count_nonzero" ]
45
45
46
46
47
- def count_nonzero (a , axis = None , * , keepdims = False ):
47
+ def count_nonzero (a , axis = None , * , keepdims = False , out = None ):
48
48
"""
49
49
Counts the number of non-zero values in the array `a`.
50
50
51
51
For full documentation refer to :obj:`numpy.count_nonzero`.
52
52
53
+ Parameters
54
+ ----------
55
+ a : {dpnp.ndarray, usm_ndarray}
56
+ The array for which to count non-zeros.
57
+ axis : {None, int, tuple}, optional
58
+ Axis or tuple of axes along which to count non-zeros.
59
+ Default value means that non-zeros will be counted along a flattened
60
+ version of `a`.
61
+ Default: ``None``.
62
+ keepdims : bool, optional
63
+ If this is set to ``True``, the axes that are counted are left in the
64
+ result as dimensions with size one. With this option, the result will
65
+ broadcast correctly against the input array.
66
+ Default: ``False``.
67
+ out : {None, dpnp.ndarray, usm_ndarray}, optional
68
+ The array into which the result is written. The data type of `out` must
69
+ match the expected shape and the expected data type of the result.
70
+ If ``None`` then a new array is returned.
71
+ Default: ``None``.
72
+
53
73
Returns
54
74
-------
55
75
out : dpnp.ndarray
56
76
Number of non-zero values in the array along a given axis.
57
- Otherwise, a zero-dimensional array with the total number of
58
- non-zero values in the array is returned.
59
-
60
- Limitations
61
- -----------
62
- Parameters `a` is supported as either :class:`dpnp.ndarray`
63
- or :class:`dpctl.tensor.usm_ndarray`.
64
- Otherwise ``TypeError`` exception will be raised.
65
- Input array data types are limited by supported DPNP :ref:`Data types`.
77
+ Otherwise, a zero-dimensional array with the total number of non-zero
78
+ values in the array is returned.
66
79
67
80
See Also
68
81
--------
@@ -87,8 +100,10 @@ def count_nonzero(a, axis=None, *, keepdims=False):
87
100
88
101
"""
89
102
90
- # TODO: might be improved by implementing an extension
91
- # with `count_nonzero` kernel
92
103
usm_a = dpnp .get_usm_ndarray (a )
93
- usm_a = dpt .astype (usm_a , dpnp .bool , copy = False )
94
- return dpnp .sum (usm_a , axis = axis , dtype = dpnp .intp , keepdims = keepdims )
104
+ usm_out = None if out is None else dpnp .get_usm_ndarray (out )
105
+
106
+ usm_res = dpt .count_nonzero (
107
+ usm_a , axis = axis , keepdims = keepdims , out = usm_out
108
+ )
109
+ return dpnp .get_result_array (usm_res , out )
0 commit comments