Skip to content

Commit 4347c09

Browse files
committed
Implement dpnp.array_equal and dpnp.array_equiv
1 parent 0c3dfe5 commit 4347c09

File tree

4 files changed

+294
-22
lines changed

4 files changed

+294
-22
lines changed

dpnp/dpnp_iface.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@
5656

5757
__all__ = [
5858
"are_same_logical_tensors",
59-
"array_equal",
6059
"asnumpy",
6160
"astype",
6261
"as_usm_ndarray",
@@ -173,24 +172,6 @@ def are_same_logical_tensors(ar1, ar2):
173172
)
174173

175174

176-
def array_equal(a1, a2, equal_nan=False):
177-
"""
178-
True if two arrays have the same shape and elements, False otherwise.
179-
180-
For full documentation refer to :obj:`numpy.array_equal`.
181-
182-
See Also
183-
--------
184-
:obj:`dpnp.allclose` : Returns True if two arrays are element-wise equal
185-
within a tolerance.
186-
:obj:`dpnp.array_equiv` : Returns True if input arrays are shape consistent
187-
and all elements equal.
188-
189-
"""
190-
191-
return numpy.array_equal(a1, a2, equal_nan=equal_nan)
192-
193-
194175
def asnumpy(a, order="C"):
195176
"""
196177
Returns the NumPy array with input data.

dpnp/dpnp_iface_logic.py

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@
5757
"all",
5858
"allclose",
5959
"any",
60+
"array_equal",
61+
"array_equiv",
6062
"equal",
6163
"greater",
6264
"greater_equal",
@@ -342,6 +344,158 @@ def any(a, /, axis=None, out=None, keepdims=False, *, where=True):
342344
return result
343345

344346

347+
def array_equal(a1, a2, equal_nan=False):
348+
"""
349+
``True`` if two arrays have the same shape and elements, ``False``
350+
otherwise.
351+
352+
Parameters
353+
----------
354+
a1 : {dpnp.ndarray, usm_ndarray, scalar}
355+
First input array, expected to have numeric data type.
356+
Both inputs `x1` and `x2` can not be scalars at the same time.
357+
a2 : {dpnp.ndarray, usm_ndarray, scalar}
358+
Second input array, also expected to have numeric data type.
359+
Both inputs `x1` and `x2` can not be scalars at the same time.
360+
equal_nan : bool
361+
Whether to compare ``NaNs`` as equal. If the dtype of `a1` and `a2` is
362+
complex, values will be considered equal if either the real or the
363+
imaginary component of a given value is ``NaNs``.
364+
365+
Returns
366+
-------
367+
b : dpnp.ndarray
368+
An array with a data type of `bool`
369+
Returns ``True`` if the arrays are equal.
370+
371+
See Also
372+
--------
373+
:obj:`dpnp.allclose`: Returns ``True`` if two arrays are element-wise equal
374+
within a tolerance.
375+
:obj:`dpnp.array_equiv`: Returns ``True`` if input arrays are shape
376+
consistent and all elements equal.
377+
378+
Examples
379+
--------
380+
>>> import dpnp as np
381+
>>> a = np.array([1, 2])
382+
>>> b = np.array([1, 2])
383+
>>> np.array_equal(a, b)
384+
array(True)
385+
386+
>>> a = np.array([1, 2])
387+
>>> b = np.array([1, 2, 3])
388+
>>> np.array_equal(a, b)
389+
array(False)
390+
391+
>>> a = np.array([1, 2])
392+
>>> b = np.array([1, 4])
393+
>>> np.array_equal(a, b)
394+
array(False)
395+
396+
>>> a = np.array([1, np.nan])
397+
>>> np.array_equal(a, a)
398+
array(False)
399+
400+
>>> np.array_equal(a, a, equal_nan=True)
401+
array(True)
402+
403+
When ``equal_nan`` is ``True``, complex values with nan components are
404+
considered equal if either the real *or* the imaginary components are nan.
405+
406+
>>> a = np.array([1 + 1j])
407+
>>> b = a.copy()
408+
>>> a.real = np.nan
409+
>>> b.imag = np.nan
410+
>>> np.array_equal(a, b, equal_nan=True)
411+
array(True)
412+
413+
"""
414+
dpnp.check_supported_arrays_type(a1, a2, scalar_type=True)
415+
416+
if a1.shape != a2.shape:
417+
return dpnp.array(False)
418+
419+
if not equal_nan:
420+
return (a1 == a2).all()
421+
422+
if a1 is a2:
423+
return dpnp.array(True)
424+
425+
cannot_have_nan = (
426+
dpnp.issubdtype(a1, dpnp.bool) or dpnp.issubdtype(a1, dpnp.integer)
427+
) and (dpnp.issubdtype(a2, dpnp.bool) or dpnp.issubdtype(a2, dpnp.integer))
428+
429+
if cannot_have_nan:
430+
return (a1 == a2).all()
431+
432+
# Handling NaN values if equal_nan is True
433+
a1nan, a2nan = isnan(a1), isnan(a2)
434+
# NaNs occur at different locations
435+
if not (a1nan == a2nan).all():
436+
return dpnp.array(False)
437+
# Shapes of a1, a2 and masks are guaranteed to be consistent by this point
438+
return (a1[~a1nan] == a2[~a1nan]).all()
439+
440+
441+
def array_equiv(a1, a2):
442+
"""
443+
Returns ``True`` if input arrays are shape consistent and all elements
444+
equal.
445+
446+
Shape consistent means they are either the same shape, or one input array
447+
can be broadcasted to create the same shape as the other one.
448+
449+
Parameters
450+
----------
451+
a1 : {dpnp.ndarray, usm_ndarray, scalar}
452+
First input array, expected to have numeric data type.
453+
Both inputs `x1` and `x2` can not be scalars at the same time.
454+
a2 : {dpnp.ndarray, usm_ndarray, scalar}
455+
Second input array, also expected to have numeric data type.
456+
Both inputs `x1` and `x2` can not be scalars at the same time.
457+
458+
Returns
459+
-------
460+
out : dpnp.ndarray
461+
An array with a data type of `bool`
462+
``True`` if equivalent, ``False`` otherwise.
463+
464+
Examples
465+
--------
466+
>>> import dpnp as np
467+
>>> a = np.array([1, 2])
468+
>>> b = np.array([1, 2])
469+
>>> c = np.array([1, 3])
470+
>>> np.array_equiv(a, b)
471+
array(True)
472+
>>> np.array_equiv(a, c)
473+
array(False)
474+
475+
Showing the shape equivalence:
476+
477+
>>> a = np.array([1, 2])
478+
>>> b = np.array([[1, 2], [1, 2]])
479+
>>> c = np.array([[1, 2, 1, 2], [1, 2, 1, 2]])
480+
>>> np.array_equiv(a, b)
481+
array(True)
482+
>>> np.array_equiv(a, c)
483+
array(False)
484+
485+
>>> a = np.array([1, 2])
486+
>>> b = np.array([[1, 2], [1, 3]])
487+
>>> np.array_equiv(a, b)
488+
array(False)
489+
490+
"""
491+
dpnp.check_supported_arrays_type(a1, a2, scalar_type=True)
492+
try:
493+
dpnp.broadcast_arrays(a1, a2)
494+
except ValueError:
495+
return dpnp.array(False)
496+
return (a1 == a2).all()
497+
498+
345499
_EQUAL_DOCSTRING = """
346500
Calculates equality test results for each element `x1_i` of the input array `x1`
347501
with the respective element `x2_i` of the input array `x2`.

tests/test_logic.py

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -494,3 +494,143 @@ def test_isclose(dtype, rtol, atol):
494494
np_res = numpy.isclose(a, b, 1e-05, 1e-08)
495495
dpnp_res = dpnp.isclose(dpnp_a, dpnp_b, rtol, atol)
496496
assert_allclose(dpnp_res, np_res)
497+
498+
499+
def _test_array_equal_parametrizations():
500+
"""
501+
we pre-create arrays as we sometime want to pass the same instance
502+
and sometime not. Passing the same instances may not mean the array are
503+
equal, especially when containing None
504+
"""
505+
# those are 0-d arrays, it used to be a special case
506+
# where (e0 == e0).all() would raise
507+
e0 = dpnp.array(0, dtype="i4")
508+
e1 = dpnp.array(1, dtype="f4")
509+
# x,y, nan_equal, expected_result
510+
yield (e0, e0.copy(), None, True)
511+
yield (e0, e0.copy(), False, True)
512+
yield (e0, e0.copy(), True, True)
513+
514+
#
515+
yield (e1, e1.copy(), None, True)
516+
yield (e1, e1.copy(), False, True)
517+
yield (e1, e1.copy(), True, True)
518+
519+
# Non-nanable – those cannot hold nans
520+
a12 = dpnp.array([1, 2])
521+
a12b = a12.copy()
522+
a123 = dpnp.array([1, 2, 3])
523+
a13 = dpnp.array([1, 3])
524+
a34 = dpnp.array([3, 4])
525+
526+
yield (a12, a12b, None, True)
527+
yield (a12, a12, None, True)
528+
yield (a12, a123, None, False)
529+
yield (a12, a34, None, False)
530+
yield (a12, a13, None, False)
531+
532+
# Non-float dtype - equal_nan should have no effect,
533+
yield (a123, a123, None, True)
534+
yield (a123, a123, False, True)
535+
yield (a123, a123, True, True)
536+
yield (a123, a123.copy(), None, True)
537+
yield (a123, a123.copy(), False, True)
538+
yield (a123, a123.copy(), True, True)
539+
yield (a123.astype("f4"), a123.astype("f4"), None, True)
540+
yield (a123.astype("f4"), a123.astype("f4"), False, True)
541+
yield (a123.astype("f4"), a123.astype("f4"), True, True)
542+
543+
# these can hold None
544+
b1 = dpnp.array([1, 2, dpnp.nan])
545+
b2 = dpnp.array([1, dpnp.nan, 2])
546+
b3 = dpnp.array([1, 2, dpnp.inf])
547+
b4 = dpnp.array(dpnp.nan)
548+
549+
# instances are the same
550+
yield (b1, b1, None, False)
551+
yield (b1, b1, False, False)
552+
yield (b1, b1, True, True)
553+
554+
# equal but not same instance
555+
yield (b1, b1.copy(), None, False)
556+
yield (b1, b1.copy(), False, False)
557+
yield (b1, b1.copy(), True, True)
558+
559+
# same once stripped of Nan
560+
yield (b1, b2, None, False)
561+
yield (b1, b2, False, False)
562+
yield (b1, b2, True, False)
563+
564+
# nan's not conflated with inf's
565+
yield (b1, b3, None, False)
566+
yield (b1, b3, False, False)
567+
yield (b1, b3, True, False)
568+
569+
# all Nan
570+
yield (b4, b4, None, False)
571+
yield (b4, b4, False, False)
572+
yield (b4, b4, True, True)
573+
yield (b4, b4.copy(), None, False)
574+
yield (b4, b4.copy(), False, False)
575+
yield (b4, b4.copy(), True, True)
576+
577+
# Multi-dimensional array
578+
md1 = dpnp.array([[0, 1], [dpnp.nan, 1]])
579+
580+
yield (md1, md1, None, False)
581+
yield (md1, md1, False, False)
582+
yield (md1, md1, True, True)
583+
yield (md1, md1.copy(), None, False)
584+
yield (md1, md1.copy(), False, False)
585+
yield (md1, md1.copy(), True, True)
586+
# both complexes are nan+nan.j but the same instance
587+
cplx1, cplx2 = [dpnp.array([dpnp.nan + dpnp.nan * 1j])] * 2
588+
589+
# only real or img are nan.
590+
cplx3 = dpnp.array(1 + dpnp.nan * 1j)
591+
cplx4 = dpnp.array(dpnp.nan + 1j)
592+
593+
# Complex values
594+
yield (cplx1, cplx2, None, False)
595+
yield (cplx1, cplx2, False, False)
596+
yield (cplx1, cplx2, True, True)
597+
598+
# Complex values, 1+nan, nan+1j
599+
yield (cplx3, cplx4, None, False)
600+
yield (cplx3, cplx4, False, False)
601+
yield (cplx3, cplx4, True, True)
602+
603+
604+
class TestArrayComparisons:
605+
@pytest.mark.parametrize(
606+
"bx,by,equal_nan,expected", _test_array_equal_parametrizations()
607+
)
608+
def test_array_equal_equal_nan(self, bx, by, equal_nan, expected):
609+
if equal_nan is None:
610+
res = dpnp.array_equal(bx, by)
611+
else:
612+
res = dpnp.array_equal(bx, by, equal_nan=equal_nan)
613+
assert_equal(res, dpnp.array(expected))
614+
615+
def test_array_equiv(self):
616+
res = dpnp.array_equiv(dpnp.array([1, 2]), dpnp.array([1, 2]))
617+
assert res
618+
res = dpnp.array_equiv(dpnp.array([1, 2]), dpnp.array([1, 2, 3]))
619+
assert not res
620+
res = dpnp.array_equiv(dpnp.array([1, 2]), dpnp.array([3, 4]))
621+
assert not res
622+
res = dpnp.array_equiv(dpnp.array([1, 2]), dpnp.array([1, 3]))
623+
assert not res
624+
625+
res = dpnp.array_equiv(dpnp.array([1, 1]), dpnp.array([1]))
626+
assert res
627+
res = dpnp.array_equiv(dpnp.array([1, 1]), dpnp.array([[1], [1]]))
628+
assert res
629+
res = dpnp.array_equiv(dpnp.array([1, 2]), dpnp.array([2]))
630+
assert not res
631+
res = dpnp.array_equiv(dpnp.array([1, 2]), dpnp.array([[1], [2]]))
632+
assert not res
633+
res = dpnp.array_equiv(
634+
dpnp.array([1, 2]), dpnp.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
635+
)
636+
assert not res

tests/third_party/cupy/logic_tests/test_comparison.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,6 @@ def test_array_equal_diff_length(self, xp, dtype):
102102
@testing.with_requires("numpy>=1.19")
103103
@testing.for_float_dtypes()
104104
@testing.numpy_cupy_equal()
105-
@pytest.mark.skip("Not supported yet")
106105
def test_array_equal_infinite_equal_nan(self, xp, dtype):
107106
nan = float("nan")
108107
inf = float("inf")
@@ -114,7 +113,6 @@ def test_array_equal_infinite_equal_nan(self, xp, dtype):
114113
@testing.with_requires("numpy>=1.19")
115114
@testing.for_complex_dtypes()
116115
@testing.numpy_cupy_equal()
117-
@pytest.mark.skip("Not supported yet")
118116
def test_array_equal_complex_equal_nan(self, xp, dtype):
119117
a = xp.array([1 + 2j], dtype=dtype)
120118
b = a.copy()
@@ -141,7 +139,6 @@ def test_array_equal_broadcast_not_allowed(self, xp):
141139
return xp.array_equal(a, b)
142140

143141

144-
@pytest.mark.skip("dpnp.array_equiv() is not implemented yet")
145142
class TestArrayEquiv(unittest.TestCase):
146143

147144
@testing.for_all_dtypes()

0 commit comments

Comments
 (0)