Skip to content

Commit 7948a0a

Browse files
npolina4antonwolfy
andauthored
Implement dpnp.array_equal and dpnp.array_equiv (#1965)
* Implement dpnp.array_equal and dpnp.array_equiv * Added CFD * Applied review comments * Added tests --------- Co-authored-by: Anton <[email protected]>
1 parent 1e5ba88 commit 7948a0a

File tree

7 files changed

+297
-24
lines changed

7 files changed

+297
-24
lines changed

doc/reference/logic.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ Comparison
6868
dpnp.allclose
6969
dpnp.isclose
7070
dpnp.array_equal
71+
dpnp.array_equiv
7172
dpnp.greater
7273
dpnp.greater_equal
7374
dpnp.less

dpnp/dpnp_iface.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@
5656

5757
__all__ = [
5858
"are_same_logical_tensors",
59-
"array_equal",
6059
"asnumpy",
6160
"astype",
6261
"as_usm_ndarray",
@@ -173,24 +172,6 @@ def are_same_logical_tensors(ar1, ar2):
173172
)
174173

175174

176-
def array_equal(a1, a2, equal_nan=False):
177-
"""
178-
True if two arrays have the same shape and elements, False otherwise.
179-
180-
For full documentation refer to :obj:`numpy.array_equal`.
181-
182-
See Also
183-
--------
184-
:obj:`dpnp.allclose` : Returns True if two arrays are element-wise equal
185-
within a tolerance.
186-
:obj:`dpnp.array_equiv` : Returns True if input arrays are shape consistent
187-
and all elements equal.
188-
189-
"""
190-
191-
return numpy.array_equal(a1, a2, equal_nan=equal_nan)
192-
193-
194175
def asnumpy(a, order="C"):
195176
"""
196177
Returns the NumPy array with input data.

dpnp/dpnp_iface_logic.py

Lines changed: 191 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,14 @@
5252
import dpnp
5353
from dpnp.dpnp_algo.dpnp_elementwise_common import DPNPBinaryFunc, DPNPUnaryFunc
5454

55+
from .dpnp_utils import get_usm_allocations
56+
5557
__all__ = [
5658
"all",
5759
"allclose",
5860
"any",
61+
"array_equal",
62+
"array_equiv",
5963
"equal",
6064
"greater",
6165
"greater_equal",
@@ -112,7 +116,7 @@ def all(a, /, axis=None, out=None, keepdims=False, *, where=True):
112116
Returns
113117
-------
114118
out : dpnp.ndarray
115-
An array with a data type of `bool`
119+
An array with a data type of `bool`.
116120
containing the results of the logical AND reduction is returned
117121
unless `out` is specified. Otherwise, a reference to `out` is returned.
118122
The result has the same shape as `a` if `axis` is not ``None``
@@ -276,7 +280,7 @@ def any(a, /, axis=None, out=None, keepdims=False, *, where=True):
276280
Returns
277281
-------
278282
out : dpnp.ndarray
279-
An array with a data type of `bool`
283+
An array with a data type of `bool`.
280284
containing the results of the logical OR reduction is returned
281285
unless `out` is specified. Otherwise, a reference to `out` is returned.
282286
The result has the same shape as `a` if `axis` is not ``None``
@@ -337,6 +341,191 @@ def any(a, /, axis=None, out=None, keepdims=False, *, where=True):
337341
return dpnp.get_result_array(usm_res, out)
338342

339343

344+
def array_equal(a1, a2, equal_nan=False):
345+
"""
346+
``True`` if two arrays have the same shape and elements, ``False``
347+
otherwise.
348+
349+
For full documentation refer to :obj:`numpy.array_equal`.
350+
351+
Parameters
352+
----------
353+
a1 : {dpnp.ndarray, usm_ndarray, scalar}
354+
First input array.
355+
Both inputs `x1` and `x2` can not be scalars at the same time.
356+
a2 : {dpnp.ndarray, usm_ndarray, scalar}
357+
Second input array.
358+
Both inputs `x1` and `x2` can not be scalars at the same time.
359+
equal_nan : bool, optional
360+
Whether to compare ``NaNs`` as equal. If the dtype of `a1` and `a2` is
361+
complex, values will be considered equal if either the real or the
362+
imaginary component of a given value is ``NaN``.
363+
Default: ``False``.
364+
365+
Returns
366+
-------
367+
b : dpnp.ndarray
368+
An array with a data type of `bool`.
369+
Returns ``True`` if the arrays are equal.
370+
371+
See Also
372+
--------
373+
:obj:`dpnp.allclose`: Returns ``True`` if two arrays are element-wise equal
374+
within a tolerance.
375+
:obj:`dpnp.array_equiv`: Returns ``True`` if input arrays are shape
376+
consistent and all elements equal.
377+
378+
Examples
379+
--------
380+
>>> import dpnp as np
381+
>>> a = np.array([1, 2])
382+
>>> b = np.array([1, 2])
383+
>>> np.array_equal(a, b)
384+
array(True)
385+
386+
>>> b = np.array([1, 2, 3])
387+
>>> np.array_equal(a, b)
388+
array(False)
389+
390+
>>> b = np.array([1, 4])
391+
>>> np.array_equal(a, b)
392+
array(False)
393+
394+
>>> a = np.array([1, np.nan])
395+
>>> np.array_equal(a, a)
396+
array(False)
397+
398+
>>> np.array_equal(a, a, equal_nan=True)
399+
array(True)
400+
401+
When ``equal_nan`` is ``True``, complex values with nan components are
402+
considered equal if either the real *or* the imaginary components are
403+
``NaNs``.
404+
405+
>>> a = np.array([1 + 1j])
406+
>>> b = a.copy()
407+
>>> a.real = np.nan
408+
>>> b.imag = np.nan
409+
>>> np.array_equal(a, b, equal_nan=True)
410+
array(True)
411+
412+
"""
413+
414+
dpnp.check_supported_arrays_type(a1, a2, scalar_type=True)
415+
if dpnp.isscalar(a1):
416+
usm_type_alloc = a2.usm_type
417+
sycl_queue_alloc = a2.sycl_queue
418+
a1 = dpnp.array(
419+
a1,
420+
dtype=dpnp.result_type(a1, a2),
421+
usm_type=usm_type_alloc,
422+
sycl_queue=sycl_queue_alloc,
423+
)
424+
elif dpnp.isscalar(a2):
425+
usm_type_alloc = a1.usm_type
426+
sycl_queue_alloc = a1.sycl_queue
427+
a2 = dpnp.array(
428+
a2,
429+
dtype=dpnp.result_type(a1, a2),
430+
usm_type=usm_type_alloc,
431+
sycl_queue=sycl_queue_alloc,
432+
)
433+
else:
434+
usm_type_alloc, sycl_queue_alloc = get_usm_allocations([a1, a2])
435+
436+
if a1.shape != a2.shape:
437+
return dpnp.array(
438+
False, usm_type=usm_type_alloc, sycl_queue=sycl_queue_alloc
439+
)
440+
441+
if not equal_nan:
442+
return (a1 == a2).all()
443+
444+
if a1 is a2:
445+
# NaN will compare equal so an array will compare equal to itself
446+
return dpnp.array(
447+
True, usm_type=usm_type_alloc, sycl_queue=sycl_queue_alloc
448+
)
449+
450+
if not (
451+
dpnp.issubdtype(a1, dpnp.inexact) or dpnp.issubdtype(a2, dpnp.inexact)
452+
):
453+
return (a1 == a2).all()
454+
455+
# Handling NaN values if equal_nan is True
456+
a1nan, a2nan = isnan(a1), isnan(a2)
457+
# NaNs occur at different locations
458+
if not (a1nan == a2nan).all():
459+
return dpnp.array(
460+
False, usm_type=usm_type_alloc, sycl_queue=sycl_queue_alloc
461+
)
462+
# Shapes of a1, a2 and masks are guaranteed to be consistent by this point
463+
return (a1[~a1nan] == a2[~a1nan]).all()
464+
465+
466+
def array_equiv(a1, a2):
467+
"""
468+
Returns ``True`` if input arrays are shape consistent and all elements
469+
equal.
470+
471+
Shape consistent means they are either the same shape, or one input array
472+
can be broadcasted to create the same shape as the other one.
473+
474+
For full documentation refer to :obj:`numpy.array_equiv`.
475+
476+
Parameters
477+
----------
478+
a1 : {dpnp.ndarray, usm_ndarray, scalar}
479+
First input array.
480+
Both inputs `x1` and `x2` can not be scalars at the same time.
481+
a2 : {dpnp.ndarray, usm_ndarray, scalar}
482+
Second input array.
483+
Both inputs `x1` and `x2` can not be scalars at the same time.
484+
485+
Returns
486+
-------
487+
out : dpnp.ndarray
488+
An array with a data type of `bool`.
489+
``True`` if equivalent, ``False`` otherwise.
490+
491+
Examples
492+
--------
493+
>>> import dpnp as np
494+
>>> a = np.array([1, 2])
495+
>>> b = np.array([1, 2])
496+
>>> c = np.array([1, 3])
497+
>>> np.array_equiv(a, b)
498+
array(True)
499+
>>> np.array_equiv(a, c)
500+
array(False)
501+
502+
Showing the shape equivalence:
503+
504+
>>> b = np.array([[1, 2], [1, 2]])
505+
>>> c = np.array([[1, 2, 1, 2], [1, 2, 1, 2]])
506+
>>> np.array_equiv(a, b)
507+
array(True)
508+
>>> np.array_equiv(a, c)
509+
array(False)
510+
511+
>>> b = np.array([[1, 2], [1, 3]])
512+
>>> np.array_equiv(a, b)
513+
array(False)
514+
515+
"""
516+
517+
dpnp.check_supported_arrays_type(a1, a2, scalar_type=True)
518+
if not dpnp.isscalar(a1) and not dpnp.isscalar(a2):
519+
usm_type_alloc, sycl_queue_alloc = get_usm_allocations([a1, a2])
520+
try:
521+
dpnp.broadcast_arrays(a1, a2)
522+
except ValueError:
523+
return dpnp.array(
524+
False, usm_type=usm_type_alloc, sycl_queue=sycl_queue_alloc
525+
)
526+
return (a1 == a2).all()
527+
528+
340529
_EQUAL_DOCSTRING = """
341530
Calculates equality test results for each element `x1_i` of the input array `x1`
342531
with the respective element `x2_i` of the input array `x2`.

tests/test_logic.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -494,3 +494,104 @@ def test_isclose(dtype, rtol, atol):
494494
np_res = numpy.isclose(a, b, 1e-05, 1e-08)
495495
dpnp_res = dpnp.isclose(dpnp_a, dpnp_b, rtol, atol)
496496
assert_allclose(dpnp_res, np_res)
497+
498+
499+
@pytest.mark.parametrize("a", [numpy.array([1, 2]), numpy.array([1, 1])])
500+
@pytest.mark.parametrize(
501+
"b",
502+
[
503+
numpy.array([1, 2]),
504+
numpy.array([1, 2, 3]),
505+
numpy.array([3, 4]),
506+
numpy.array([1, 3]),
507+
numpy.array([1]),
508+
numpy.array([[1], [1]]),
509+
numpy.array([2]),
510+
numpy.array([[1], [2]]),
511+
numpy.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]),
512+
],
513+
)
514+
def test_array_equiv(a, b):
515+
result = dpnp.array_equiv(dpnp.array(a), dpnp.array(b))
516+
expected = numpy.array_equiv(a, b)
517+
518+
assert_equal(expected, result)
519+
520+
521+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True, no_complex=True))
522+
def test_array_equiv_dtype(dtype):
523+
a = numpy.array([1, 2], dtype=dtype)
524+
b = numpy.array([1, 2], dtype=dtype)
525+
c = numpy.array([1, 3], dtype=dtype)
526+
527+
result = dpnp.array_equiv(dpnp.array(a), dpnp.array(b))
528+
expected = numpy.array_equiv(a, b)
529+
530+
assert_equal(expected, result)
531+
532+
result = dpnp.array_equiv(dpnp.array(a), dpnp.array(c))
533+
expected = numpy.array_equiv(a, c)
534+
535+
assert_equal(expected, result)
536+
537+
538+
@pytest.mark.parametrize("a", [numpy.array([1, 2]), numpy.array([1, 1])])
539+
def test_array_equiv_scalar(a):
540+
b = 1
541+
result = dpnp.array_equiv(dpnp.array(a), b)
542+
expected = numpy.array_equiv(a, b)
543+
544+
assert_equal(expected, result)
545+
546+
547+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True, no_complex=True))
548+
@pytest.mark.parametrize("equal_nan", [True, False])
549+
def test_array_equal_dtype(dtype, equal_nan):
550+
a = numpy.array([1, 2], dtype=dtype)
551+
b = numpy.array([1, 2], dtype=dtype)
552+
c = numpy.array([1, 3], dtype=dtype)
553+
554+
result = dpnp.array_equal(dpnp.array(a), dpnp.array(b), equal_nan=equal_nan)
555+
expected = numpy.array_equal(a, b, equal_nan=equal_nan)
556+
557+
assert_equal(expected, result)
558+
559+
result = dpnp.array_equal(dpnp.array(a), dpnp.array(c), equal_nan=equal_nan)
560+
expected = numpy.array_equal(a, c, equal_nan=equal_nan)
561+
562+
assert_equal(expected, result)
563+
564+
565+
@pytest.mark.parametrize(
566+
"a",
567+
[
568+
numpy.array([1, 2]),
569+
numpy.array([1.0, numpy.nan]),
570+
numpy.array([1.0, numpy.inf]),
571+
],
572+
)
573+
def test_array_equal_same_arr(a):
574+
expected = numpy.array_equal(a, a)
575+
b = dpnp.array(a)
576+
result = dpnp.array_equal(b, b)
577+
assert_equal(expected, result)
578+
579+
expected = numpy.array_equal(a, a, equal_nan=True)
580+
result = dpnp.array_equal(b, b, equal_nan=True)
581+
assert_equal(expected, result)
582+
583+
584+
@pytest.mark.parametrize(
585+
"a",
586+
[
587+
numpy.array([1, 2]),
588+
numpy.array([1.0, numpy.nan]),
589+
numpy.array([1.0, numpy.inf]),
590+
],
591+
)
592+
def test_array_equal_nan(a):
593+
a = numpy.array([1.0, numpy.nan])
594+
b = numpy.array([1.0, 2.0])
595+
result = dpnp.array_equal(dpnp.array(a), dpnp.array(b), equal_nan=True)
596+
expected = numpy.array_equal(a, b, equal_nan=True)
597+
assert_equal(expected, result)

tests/test_sycl_queue.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -752,6 +752,8 @@ def test_2in_1out(func, data1, data2, device):
752752
@pytest.mark.parametrize(
753753
"op",
754754
[
755+
"array_equal",
756+
"array_equiv",
755757
"equal",
756758
"greater",
757759
"greater_equal",

tests/test_usm_type.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,8 @@ def test_coerced_usm_types_logic_op_1in(op, usm_type_x):
378378
@pytest.mark.parametrize(
379379
"op",
380380
[
381+
"array_equal",
382+
"array_equiv",
381383
"equal",
382384
"greater",
383385
"greater_equal",

0 commit comments

Comments
 (0)