Skip to content

Commit 323ce50

Browse files
Support out parameter for dpnp.all/any() (#1893)
* Update dpnp.all/any with support out param * Update cupy tests * Add TestAllAny * Update dpnp tests * Apply comments --------- Co-authored-by: Anton <[email protected]>
1 parent c78f28a commit 323ce50

File tree

5 files changed

+200
-105
lines changed

5 files changed

+200
-105
lines changed

dpnp/dpnp_iface_logic.py

Lines changed: 125 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646

4747

4848
import dpctl.tensor as dpt
49-
import dpctl.tensor._tensor_elementwise_impl as ti
49+
import dpctl.tensor._tensor_elementwise_impl as tei
5050
import numpy
5151

5252
import dpnp
@@ -76,25 +76,48 @@
7676
]
7777

7878

79-
def all(x, /, axis=None, out=None, keepdims=False, *, where=True):
79+
def all(a, /, axis=None, out=None, keepdims=False, *, where=True):
8080
"""
8181
Test whether all array elements along a given axis evaluate to True.
8282
8383
For full documentation refer to :obj:`numpy.all`.
8484
85+
Parameters
86+
----------
87+
a : {dpnp.ndarray, usm_ndarray}
88+
Input array.
89+
axis : {None, int, tuple of ints}, optional
90+
Axis or axes along which a logical AND reduction is performed.
91+
The default is to perform a logical AND over all the dimensions
92+
of the input array.`axis` may be negative, in which case it counts
93+
from the last to the first axis.
94+
Default: ``None``.
95+
out : {None, dpnp.ndarray, usm_ndarray}, optional
96+
Alternative output array in which to place the result. It must have
97+
the same shape as the expected output but the type (of the returned
98+
values) will be cast if necessary.
99+
Default: ``None``.
100+
keepdims : bool, optional
101+
If ``True``, the reduced axes (dimensions) are included in the result
102+
as singleton dimensions, so that the returned array remains
103+
compatible with the input array according to Array Broadcasting
104+
rules. Otherwise, if ``False``, the reduced axes are not included in
105+
the returned array.
106+
Default: ``False``.
107+
85108
Returns
86109
-------
87110
out : dpnp.ndarray
88111
An array with a data type of `bool`
89-
containing the results of the logical AND reduction.
112+
containing the results of the logical AND reduction is returned
113+
unless `out` is specified. Otherwise, a reference to `out` is returned.
114+
The result has the same shape as `a` if `axis` is not ``None``
115+
or `a` is a 0-d array.
90116
91117
Limitations
92118
-----------
93-
Parameters `x` is supported either as :class:`dpnp.ndarray`
94-
or :class:`dpctl.tensor.usm_ndarray`.
95-
Parameters `out` and `where` are supported with default value.
96-
Input array data types are limited by supported DPNP :ref:`Data types`.
97-
Otherwise the function will be executed sequentially on CPU.
119+
Parameters `where` is only supported with its default value.
120+
Otherwise ``NotImplementedError`` exception will be raised.
98121
99122
See Also
100123
--------
@@ -105,7 +128,7 @@ def all(x, /, axis=None, out=None, keepdims=False, *, where=True):
105128
Notes
106129
-----
107130
Not a Number (NaN), positive infinity and negative infinity
108-
evaluate to `True` because these are not equal to zero.
131+
evaluate to ``True`` because these are not equal to zero.
109132
110133
Examples
111134
--------
@@ -125,22 +148,27 @@ def all(x, /, axis=None, out=None, keepdims=False, *, where=True):
125148
>>> np.all(x3)
126149
array(True)
127150
151+
>>> o = np.array(False)
152+
>>> z = np.all(x2, out=o)
153+
>>> z, o
154+
(array(True), array(True))
155+
>>> # Check now that `z` is a reference to `o`
156+
>>> z is o
157+
True
158+
>>> id(z), id(o) # identity of `z` and `o`
159+
(139884456208480, 139884456208480) # may vary
160+
128161
"""
129162

130-
if dpnp.is_supported_array_type(x):
131-
if out is not None:
132-
pass
133-
elif where is not True:
134-
pass
135-
else:
136-
dpt_array = dpnp.get_usm_ndarray(x)
137-
return dpnp_array._create_from_usm_ndarray(
138-
dpt.all(dpt_array, axis=axis, keepdims=keepdims)
139-
)
163+
dpnp.check_limitations(where=where)
140164

141-
return call_origin(
142-
numpy.all, x, axis=axis, out=out, keepdims=keepdims, where=where
165+
dpt_array = dpnp.get_usm_ndarray(a)
166+
result = dpnp_array._create_from_usm_ndarray(
167+
dpt.all(dpt_array, axis=axis, keepdims=keepdims)
143168
)
169+
# TODO: temporary solution until dpt.all supports out parameter
170+
result = dpnp.get_result_array(result, out)
171+
return result
144172

145173

146174
def allclose(a, b, rtol=1.0e-5, atol=1.0e-8, **kwargs):
@@ -238,25 +266,48 @@ def allclose(a, b, rtol=1.0e-5, atol=1.0e-8, **kwargs):
238266
return call_origin(numpy.allclose, a, b, rtol=rtol, atol=atol, **kwargs)
239267

240268

241-
def any(x, /, axis=None, out=None, keepdims=False, *, where=True):
269+
def any(a, /, axis=None, out=None, keepdims=False, *, where=True):
242270
"""
243271
Test whether any array element along a given axis evaluates to True.
244272
245273
For full documentation refer to :obj:`numpy.any`.
246274
275+
Parameters
276+
----------
277+
a : {dpnp.ndarray, usm_ndarray}
278+
Input array.
279+
axis : {None, int, tuple of ints}, optional
280+
Axis or axes along which a logical OR reduction is performed.
281+
The default is to perform a logical OR over all the dimensions
282+
of the input array.`axis` may be negative, in which case it counts
283+
from the last to the first axis.
284+
Default: ``None``.
285+
out : {None, dpnp.ndarray, usm_ndarray}, optional
286+
Alternative output array in which to place the result. It must have
287+
the same shape as the expected output but the type (of the returned
288+
values) will be cast if necessary.
289+
Default: ``None``.
290+
keepdims : bool, optional
291+
If ``True``, the reduced axes (dimensions) are included in the result
292+
as singleton dimensions, so that the returned array remains
293+
compatible with the input array according to Array Broadcasting
294+
rules. Otherwise, if ``False``, the reduced axes are not included in
295+
the returned array.
296+
Default: ``False``.
297+
247298
Returns
248299
-------
249300
out : dpnp.ndarray
250301
An array with a data type of `bool`
251-
containing the results of the logical OR reduction.
302+
containing the results of the logical OR reduction is returned
303+
unless `out` is specified. Otherwise, a reference to `out` is returned.
304+
The result has the same shape as `a` if `axis` is not ``None``
305+
or `a` is a 0-d array.
252306
253307
Limitations
254308
-----------
255-
Parameters `x` is supported either as :class:`dpnp.ndarray`
256-
or :class:`dpctl.tensor.usm_ndarray`.
257-
Parameters `out` and `where` are supported with default value.
258-
Input array data types are limited by supported DPNP :ref:`Data types`.
259-
Otherwise the function will be executed sequentially on CPU.
309+
Parameters `where` is only supported with its default value.
310+
Otherwise ``NotImplementedError`` exception will be raised.
260311
261312
See Also
262313
--------
@@ -267,7 +318,7 @@ def any(x, /, axis=None, out=None, keepdims=False, *, where=True):
267318
Notes
268319
-----
269320
Not a Number (NaN), positive infinity and negative infinity evaluate
270-
to `True` because these are not equal to zero.
321+
to ``True`` because these are not equal to zero.
271322
272323
Examples
273324
--------
@@ -279,30 +330,35 @@ def any(x, /, axis=None, out=None, keepdims=False, *, where=True):
279330
>>> np.any(x, axis=0)
280331
array([ True, True])
281332
282-
>>> x2 = np.array([0, 0, 0])
333+
>>> x2 = np.array([-1, 0, 5])
283334
>>> np.any(x2)
284-
array(False)
335+
array(True)
285336
286337
>>> x3 = np.array([1.0, np.nan])
287338
>>> np.any(x3)
288339
array(True)
289340
341+
>>> o = np.array(False)
342+
>>> z = np.any(x2, out=o)
343+
>>> z, o
344+
(array(True), array(True))
345+
>>> # Check now that `z` is a reference to `o`
346+
>>> z is o
347+
True
348+
>>> id(z), id(o) # identity of `z` and `o`
349+
>>> (140053638309840, 140053638309840) # may vary
350+
290351
"""
291352

292-
if dpnp.is_supported_array_type(x):
293-
if out is not None:
294-
pass
295-
elif where is not True:
296-
pass
297-
else:
298-
dpt_array = dpnp.get_usm_ndarray(x)
299-
return dpnp_array._create_from_usm_ndarray(
300-
dpt.any(dpt_array, axis=axis, keepdims=keepdims)
301-
)
353+
dpnp.check_limitations(where=where)
302354

303-
return call_origin(
304-
numpy.any, x, axis=axis, out=out, keepdims=keepdims, where=where
355+
dpt_array = dpnp.get_usm_ndarray(a)
356+
result = dpnp_array._create_from_usm_ndarray(
357+
dpt.any(dpt_array, axis=axis, keepdims=keepdims)
305358
)
359+
# TODO: temporary solution until dpt.any supports out parameter
360+
result = dpnp.get_result_array(result, out)
361+
return result
306362

307363

308364
_EQUAL_DOCSTRING = """
@@ -368,8 +424,8 @@ def any(x, /, axis=None, out=None, keepdims=False, *, where=True):
368424

369425
equal = DPNPBinaryFunc(
370426
"equal",
371-
ti._equal_result_type,
372-
ti._equal,
427+
tei._equal_result_type,
428+
tei._equal,
373429
_EQUAL_DOCSTRING,
374430
)
375431

@@ -431,8 +487,8 @@ def any(x, /, axis=None, out=None, keepdims=False, *, where=True):
431487

432488
greater = DPNPBinaryFunc(
433489
"greater",
434-
ti._greater_result_type,
435-
ti._greater,
490+
tei._greater_result_type,
491+
tei._greater,
436492
_GREATER_DOCSTRING,
437493
)
438494

@@ -495,8 +551,8 @@ def any(x, /, axis=None, out=None, keepdims=False, *, where=True):
495551

496552
greater_equal = DPNPBinaryFunc(
497553
"greater",
498-
ti._greater_equal_result_type,
499-
ti._greater_equal,
554+
tei._greater_equal_result_type,
555+
tei._greater_equal,
500556
_GREATER_EQUAL_DOCSTRING,
501557
)
502558

@@ -597,8 +653,8 @@ def isclose(x1, x2, rtol=1e-05, atol=1e-08, equal_nan=False):
597653

598654
isfinite = DPNPUnaryFunc(
599655
"isfinite",
600-
ti._isfinite_result_type,
601-
ti._isfinite,
656+
tei._isfinite_result_type,
657+
tei._isfinite,
602658
_ISFINITE_DOCSTRING,
603659
)
604660

@@ -650,8 +706,8 @@ def isclose(x1, x2, rtol=1e-05, atol=1e-08, equal_nan=False):
650706

651707
isinf = DPNPUnaryFunc(
652708
"isinf",
653-
ti._isinf_result_type,
654-
ti._isinf,
709+
tei._isinf_result_type,
710+
tei._isinf,
655711
_ISINF_DOCSTRING,
656712
)
657713

@@ -704,8 +760,8 @@ def isclose(x1, x2, rtol=1e-05, atol=1e-08, equal_nan=False):
704760

705761
isnan = DPNPUnaryFunc(
706762
"isnan",
707-
ti._isnan_result_type,
708-
ti._isnan,
763+
tei._isnan_result_type,
764+
tei._isnan,
709765
_ISNAN_DOCSTRING,
710766
)
711767

@@ -767,8 +823,8 @@ def isclose(x1, x2, rtol=1e-05, atol=1e-08, equal_nan=False):
767823

768824
less = DPNPBinaryFunc(
769825
"less",
770-
ti._less_result_type,
771-
ti._less,
826+
tei._less_result_type,
827+
tei._less,
772828
_LESS_DOCSTRING,
773829
)
774830

@@ -830,8 +886,8 @@ def isclose(x1, x2, rtol=1e-05, atol=1e-08, equal_nan=False):
830886

831887
less_equal = DPNPBinaryFunc(
832888
"less_equal",
833-
ti._less_equal_result_type,
834-
ti._less_equal,
889+
tei._less_equal_result_type,
890+
tei._less_equal,
835891
_LESS_EQUAL_DOCSTRING,
836892
)
837893

@@ -895,8 +951,8 @@ def isclose(x1, x2, rtol=1e-05, atol=1e-08, equal_nan=False):
895951

896952
logical_and = DPNPBinaryFunc(
897953
"logical_and",
898-
ti._logical_and_result_type,
899-
ti._logical_and,
954+
tei._logical_and_result_type,
955+
tei._logical_and,
900956
_LOGICAL_AND_DOCSTRING,
901957
)
902958

@@ -947,8 +1003,8 @@ def isclose(x1, x2, rtol=1e-05, atol=1e-08, equal_nan=False):
9471003

9481004
logical_not = DPNPUnaryFunc(
9491005
"logical_not",
950-
ti._logical_not_result_type,
951-
ti._logical_not,
1006+
tei._logical_not_result_type,
1007+
tei._logical_not,
9521008
_LOGICAL_NOT_DOCSTRING,
9531009
)
9541010

@@ -1012,8 +1068,8 @@ def isclose(x1, x2, rtol=1e-05, atol=1e-08, equal_nan=False):
10121068

10131069
logical_or = DPNPBinaryFunc(
10141070
"logical_or",
1015-
ti._logical_or_result_type,
1016-
ti._logical_or,
1071+
tei._logical_or_result_type,
1072+
tei._logical_or,
10171073
_LOGICAL_OR_DOCSTRING,
10181074
)
10191075

@@ -1075,8 +1131,8 @@ def isclose(x1, x2, rtol=1e-05, atol=1e-08, equal_nan=False):
10751131

10761132
logical_xor = DPNPBinaryFunc(
10771133
"logical_xor",
1078-
ti._logical_xor_result_type,
1079-
ti._logical_xor,
1134+
tei._logical_xor_result_type,
1135+
tei._logical_xor,
10801136
_LOGICAL_XOR_DOCSTRING,
10811137
)
10821138

@@ -1138,7 +1194,7 @@ def isclose(x1, x2, rtol=1e-05, atol=1e-08, equal_nan=False):
11381194

11391195
not_equal = DPNPBinaryFunc(
11401196
"not_equal",
1141-
ti._not_equal_result_type,
1142-
ti._not_equal,
1197+
tei._not_equal_result_type,
1198+
tei._not_equal,
11431199
_NOT_EQUAL_DOCSTRING,
11441200
)

0 commit comments

Comments
 (0)