Skip to content

Commit 9207515

Browse files
committed
Applied review comments
1 parent 1ea5551 commit 9207515

File tree

3 files changed

+79
-187
lines changed

3 files changed

+79
-187
lines changed

doc/reference/logic.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ Comparison
6868
dpnp.allclose
6969
dpnp.isclose
7070
dpnp.array_equal
71+
dpnp.array_equiv
7172
dpnp.greater
7273
dpnp.greater_equal
7374
dpnp.less

dpnp/dpnp_iface_logic.py

Lines changed: 34 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,10 @@
5050
import numpy
5151

5252
import dpnp
53-
import dpnp.dpnp_utils as utils
5453
from dpnp.dpnp_algo.dpnp_elementwise_common import DPNPBinaryFunc, DPNPUnaryFunc
5554

55+
from .dpnp_utils import get_usm_allocations
56+
5657
__all__ = [
5758
"all",
5859
"allclose",
@@ -115,7 +116,7 @@ def all(a, /, axis=None, out=None, keepdims=False, *, where=True):
115116
Returns
116117
-------
117118
out : dpnp.ndarray
118-
An array with a data type of `bool`
119+
An array with a data type of `bool`.
119120
containing the results of the logical AND reduction is returned
120121
unless `out` is specified. Otherwise, a reference to `out` is returned.
121122
The result has the same shape as `a` if `axis` is not ``None``
@@ -279,7 +280,7 @@ def any(a, /, axis=None, out=None, keepdims=False, *, where=True):
279280
Returns
280281
-------
281282
out : dpnp.ndarray
282-
An array with a data type of `bool`
283+
An array with a data type of `bool`.
283284
containing the results of the logical OR reduction is returned
284285
unless `out` is specified. Otherwise, a reference to `out` is returned.
285286
The result has the same shape as `a` if `axis` is not ``None``
@@ -345,23 +346,26 @@ def array_equal(a1, a2, equal_nan=False):
345346
``True`` if two arrays have the same shape and elements, ``False``
346347
otherwise.
347348
349+
For full documentation refer to :obj:`numpy.array_equal`.
350+
348351
Parameters
349352
----------
350353
a1 : {dpnp.ndarray, usm_ndarray, scalar}
351-
First input array, expected to have numeric data type.
354+
First input array.
352355
Both inputs `x1` and `x2` can not be scalars at the same time.
353356
a2 : {dpnp.ndarray, usm_ndarray, scalar}
354-
Second input array, also expected to have numeric data type.
357+
Second input array.
355358
Both inputs `x1` and `x2` can not be scalars at the same time.
356-
equal_nan : bool
359+
equal_nan : bool, optional
357360
Whether to compare ``NaNs`` as equal. If the dtype of `a1` and `a2` is
358361
complex, values will be considered equal if either the real or the
359-
imaginary component of a given value is ``NaNs``.
362+
imaginary component of a given value is ``NaN``.
363+
Default: ``False``.
360364
361365
Returns
362366
-------
363367
b : dpnp.ndarray
364-
An array with a data type of `bool`
368+
An array with a data type of `bool`.
365369
Returns ``True`` if the arrays are equal.
366370
367371
See Also
@@ -379,12 +383,10 @@ def array_equal(a1, a2, equal_nan=False):
379383
>>> np.array_equal(a, b)
380384
array(True)
381385
382-
>>> a = np.array([1, 2])
383386
>>> b = np.array([1, 2, 3])
384387
>>> np.array_equal(a, b)
385388
array(False)
386389
387-
>>> a = np.array([1, 2])
388390
>>> b = np.array([1, 4])
389391
>>> np.array_equal(a, b)
390392
array(False)
@@ -397,7 +399,8 @@ def array_equal(a1, a2, equal_nan=False):
397399
array(True)
398400
399401
When ``equal_nan`` is ``True``, complex values with nan components are
400-
considered equal if either the real *or* the imaginary components are nan.
402+
considered equal if either the real *or* the imaginary components are
403+
``NaNs``.
401404
402405
>>> a = np.array([1 + 1j])
403406
>>> b = a.copy()
@@ -407,6 +410,7 @@ def array_equal(a1, a2, equal_nan=False):
407410
array(True)
408411
409412
"""
413+
410414
dpnp.check_supported_arrays_type(a1, a2, scalar_type=True)
411415
if dpnp.isscalar(a1):
412416
usm_type_alloc = a2.usm_type
@@ -427,7 +431,7 @@ def array_equal(a1, a2, equal_nan=False):
427431
sycl_queue=sycl_queue_alloc,
428432
)
429433
else:
430-
usm_type_alloc, sycl_queue_alloc = utils.get_usm_allocations([a1, a2])
434+
usm_type_alloc, sycl_queue_alloc = get_usm_allocations([a1, a2])
431435

432436
if a1.shape != a2.shape:
433437
return dpnp.array(
@@ -438,15 +442,14 @@ def array_equal(a1, a2, equal_nan=False):
438442
return (a1 == a2).all()
439443

440444
if a1 is a2:
445+
# NaN will compare equal so an array will compare equal to itself
441446
return dpnp.array(
442447
True, usm_type=usm_type_alloc, sycl_queue=sycl_queue_alloc
443448
)
444449

445-
cannot_have_nan = (
446-
dpnp.issubdtype(a1, dpnp.bool) or dpnp.issubdtype(a1, dpnp.integer)
447-
) and (dpnp.issubdtype(a2, dpnp.bool) or dpnp.issubdtype(a2, dpnp.integer))
448-
449-
if cannot_have_nan:
450+
if not (
451+
dpnp.issubdtype(a1, dpnp.inexact) or dpnp.issubdtype(a2, dpnp.inexact)
452+
):
450453
return (a1 == a2).all()
451454

452455
# Handling NaN values if equal_nan is True
@@ -468,19 +471,21 @@ def array_equiv(a1, a2):
468471
Shape consistent means they are either the same shape, or one input array
469472
can be broadcasted to create the same shape as the other one.
470473
474+
For full documentation refer to :obj:`numpy.array_equiv`.
475+
471476
Parameters
472477
----------
473478
a1 : {dpnp.ndarray, usm_ndarray, scalar}
474-
First input array, expected to have numeric data type.
479+
First input array.
475480
Both inputs `x1` and `x2` can not be scalars at the same time.
476481
a2 : {dpnp.ndarray, usm_ndarray, scalar}
477-
Second input array, also expected to have numeric data type.
482+
Second input array.
478483
Both inputs `x1` and `x2` can not be scalars at the same time.
479484
480485
Returns
481486
-------
482487
out : dpnp.ndarray
483-
An array with a data type of `bool`
488+
An array with a data type of `bool`.
484489
``True`` if equivalent, ``False`` otherwise.
485490
486491
Examples
@@ -496,48 +501,28 @@ def array_equiv(a1, a2):
496501
497502
Showing the shape equivalence:
498503
499-
>>> a = np.array([1, 2])
500504
>>> b = np.array([[1, 2], [1, 2]])
501505
>>> c = np.array([[1, 2, 1, 2], [1, 2, 1, 2]])
502506
>>> np.array_equiv(a, b)
503507
array(True)
504508
>>> np.array_equiv(a, c)
505509
array(False)
506510
507-
>>> a = np.array([1, 2])
508511
>>> b = np.array([[1, 2], [1, 3]])
509512
>>> np.array_equiv(a, b)
510513
array(False)
511514
512515
"""
513-
dpnp.check_supported_arrays_type(a1, a2, scalar_type=True)
514-
if dpnp.isscalar(a1):
515-
usm_type_alloc = a2.usm_type
516-
sycl_queue_alloc = a2.sycl_queue
517-
a1 = dpnp.array(
518-
a1,
519-
dtype=dpnp.result_type(a1, a2),
520-
usm_type=usm_type_alloc,
521-
sycl_queue=sycl_queue_alloc,
522-
)
523-
elif dpnp.isscalar(a2):
524-
usm_type_alloc = a1.usm_type
525-
sycl_queue_alloc = a1.sycl_queue
526-
a2 = dpnp.array(
527-
a2,
528-
dtype=dpnp.result_type(a1, a2),
529-
usm_type=usm_type_alloc,
530-
sycl_queue=sycl_queue_alloc,
531-
)
532-
else:
533-
usm_type_alloc, sycl_queue_alloc = utils.get_usm_allocations([a1, a2])
534516

535-
try:
536-
dpnp.broadcast_arrays(a1, a2)
537-
except ValueError:
538-
return dpnp.array(
539-
False, usm_type=usm_type_alloc, sycl_queue=sycl_queue_alloc
540-
)
517+
dpnp.check_supported_arrays_type(a1, a2, scalar_type=True)
518+
if not dpnp.isscalar(a1) and not dpnp.isscalar(a2):
519+
usm_type_alloc, sycl_queue_alloc = get_usm_allocations([a1, a2])
520+
try:
521+
dpnp.broadcast_arrays(a1, a2)
522+
except ValueError:
523+
return dpnp.array(
524+
False, usm_type=usm_type_alloc, sycl_queue=sycl_queue_alloc
525+
)
541526
return (a1 == a2).all()
542527

543528

tests/test_logic.py

Lines changed: 44 additions & 138 deletions
Original file line numberDiff line numberDiff line change
@@ -496,141 +496,47 @@ def test_isclose(dtype, rtol, atol):
496496
assert_allclose(dpnp_res, np_res)
497497

498498

499-
def _test_array_equal_parametrizations():
500-
"""
501-
we pre-create arrays as we sometime want to pass the same instance
502-
and sometime not. Passing the same instances may not mean the array are
503-
equal, especially when containing None
504-
"""
505-
# those are 0-d arrays, it used to be a special case
506-
# where (e0 == e0).all() would raise
507-
e0 = dpnp.array(0, dtype="i4")
508-
e1 = dpnp.array(1, dtype="f4")
509-
# x,y, nan_equal, expected_result
510-
yield (e0, e0.copy(), None, True)
511-
yield (e0, e0.copy(), False, True)
512-
yield (e0, e0.copy(), True, True)
513-
514-
#
515-
yield (e1, e1.copy(), None, True)
516-
yield (e1, e1.copy(), False, True)
517-
yield (e1, e1.copy(), True, True)
518-
519-
# Non-nanable – those cannot hold nans
520-
a12 = dpnp.array([1, 2])
521-
a12b = a12.copy()
522-
a123 = dpnp.array([1, 2, 3])
523-
a13 = dpnp.array([1, 3])
524-
a34 = dpnp.array([3, 4])
525-
526-
yield (a12, a12b, None, True)
527-
yield (a12, a12, None, True)
528-
yield (a12, a123, None, False)
529-
yield (a12, a34, None, False)
530-
yield (a12, a13, None, False)
531-
532-
# Non-float dtype - equal_nan should have no effect,
533-
yield (a123, a123, None, True)
534-
yield (a123, a123, False, True)
535-
yield (a123, a123, True, True)
536-
yield (a123, a123.copy(), None, True)
537-
yield (a123, a123.copy(), False, True)
538-
yield (a123, a123.copy(), True, True)
539-
yield (a123.astype("f4"), a123.astype("f4"), None, True)
540-
yield (a123.astype("f4"), a123.astype("f4"), False, True)
541-
yield (a123.astype("f4"), a123.astype("f4"), True, True)
542-
543-
# these can hold None
544-
b1 = dpnp.array([1, 2, dpnp.nan])
545-
b2 = dpnp.array([1, dpnp.nan, 2])
546-
b3 = dpnp.array([1, 2, dpnp.inf])
547-
b4 = dpnp.array(dpnp.nan)
548-
549-
# instances are the same
550-
yield (b1, b1, None, False)
551-
yield (b1, b1, False, False)
552-
yield (b1, b1, True, True)
553-
554-
# equal but not same instance
555-
yield (b1, b1.copy(), None, False)
556-
yield (b1, b1.copy(), False, False)
557-
yield (b1, b1.copy(), True, True)
558-
559-
# same once stripped of Nan
560-
yield (b1, b2, None, False)
561-
yield (b1, b2, False, False)
562-
yield (b1, b2, True, False)
563-
564-
# nan's not conflated with inf's
565-
yield (b1, b3, None, False)
566-
yield (b1, b3, False, False)
567-
yield (b1, b3, True, False)
568-
569-
# all Nan
570-
yield (b4, b4, None, False)
571-
yield (b4, b4, False, False)
572-
yield (b4, b4, True, True)
573-
yield (b4, b4.copy(), None, False)
574-
yield (b4, b4.copy(), False, False)
575-
yield (b4, b4.copy(), True, True)
576-
577-
# Multi-dimensional array
578-
md1 = dpnp.array([[0, 1], [dpnp.nan, 1]])
579-
580-
yield (md1, md1, None, False)
581-
yield (md1, md1, False, False)
582-
yield (md1, md1, True, True)
583-
yield (md1, md1.copy(), None, False)
584-
yield (md1, md1.copy(), False, False)
585-
yield (md1, md1.copy(), True, True)
586-
# both complexes are nan+nan.j but the same instance
587-
cplx1, cplx2 = [dpnp.array([dpnp.nan + dpnp.nan * 1j])] * 2
588-
589-
# only real or img are nan.
590-
cplx3 = dpnp.array(1 + dpnp.nan * 1j)
591-
cplx4 = dpnp.array(dpnp.nan + 1j)
592-
593-
# Complex values
594-
yield (cplx1, cplx2, None, False)
595-
yield (cplx1, cplx2, False, False)
596-
yield (cplx1, cplx2, True, True)
597-
598-
# Complex values, 1+nan, nan+1j
599-
yield (cplx3, cplx4, None, False)
600-
yield (cplx3, cplx4, False, False)
601-
yield (cplx3, cplx4, True, True)
602-
603-
604-
class TestArrayComparisons:
605-
@pytest.mark.parametrize(
606-
"bx,by,equal_nan,expected", _test_array_equal_parametrizations()
607-
)
608-
def test_array_equal_equal_nan(self, bx, by, equal_nan, expected):
609-
if equal_nan is None:
610-
res = dpnp.array_equal(bx, by)
611-
else:
612-
res = dpnp.array_equal(bx, by, equal_nan=equal_nan)
613-
assert_equal(res, dpnp.array(expected))
614-
615-
def test_array_equiv(self):
616-
res = dpnp.array_equiv(dpnp.array([1, 2]), dpnp.array([1, 2]))
617-
assert res
618-
res = dpnp.array_equiv(dpnp.array([1, 2]), dpnp.array([1, 2, 3]))
619-
assert not res
620-
res = dpnp.array_equiv(dpnp.array([1, 2]), dpnp.array([3, 4]))
621-
assert not res
622-
res = dpnp.array_equiv(dpnp.array([1, 2]), dpnp.array([1, 3]))
623-
assert not res
624-
625-
res = dpnp.array_equiv(dpnp.array([1, 1]), dpnp.array([1]))
626-
assert res
627-
res = dpnp.array_equiv(dpnp.array([1, 1]), dpnp.array([[1], [1]]))
628-
assert res
629-
res = dpnp.array_equiv(dpnp.array([1, 2]), dpnp.array([2]))
630-
assert not res
631-
res = dpnp.array_equiv(dpnp.array([1, 2]), dpnp.array([[1], [2]]))
632-
assert not res
633-
res = dpnp.array_equiv(
634-
dpnp.array([1, 2]), dpnp.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
635-
)
636-
assert not res
499+
@pytest.mark.parametrize("a", [numpy.array([1, 2]), numpy.array([1, 1])])
500+
@pytest.mark.parametrize(
501+
"b",
502+
[
503+
numpy.array([1, 2]),
504+
numpy.array([1, 2, 3]),
505+
numpy.array([3, 4]),
506+
numpy.array([1, 3]),
507+
numpy.array([1]),
508+
numpy.array([[1], [1]]),
509+
numpy.array([2]),
510+
numpy.array([[1], [2]]),
511+
numpy.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]),
512+
],
513+
)
514+
def test_array_equiv(a, b):
515+
result = dpnp.array_equiv(dpnp.array(a), dpnp.array(b))
516+
expected = numpy.array_equiv(a, b)
517+
518+
assert_equal(expected, result)
519+
520+
521+
@pytest.mark.parametrize("a", [numpy.array([1, 2]), numpy.array([1, 1])])
522+
def test_array_equiv_scalar(a):
523+
b = 1
524+
result = dpnp.array_equiv(dpnp.array(a), b)
525+
expected = numpy.array_equiv(a, b)
526+
527+
assert_equal(expected, result)
528+
529+
530+
@pytest.mark.parametrize(
531+
"a",
532+
[
533+
numpy.array([1, 2]),
534+
numpy.array([1.0, numpy.nan]),
535+
numpy.array([1.0, numpy.inf]),
536+
],
537+
)
538+
def test_array_equal_same_arr(a):
539+
expected = numpy.array_equal(a, a.copy())
540+
b = dpnp.array(a)
541+
result = dpnp.array_equal(b, b.copy())
542+
assert_equal(expected, result)

0 commit comments

Comments
 (0)