|
52 | 52 | import dpnp
|
53 | 53 | from dpnp.dpnp_algo.dpnp_elementwise_common import DPNPBinaryFunc, DPNPUnaryFunc
|
54 | 54 |
|
| 55 | +from .dpnp_utils import get_usm_allocations |
| 56 | + |
55 | 57 | __all__ = [
|
56 | 58 | "all",
|
57 | 59 | "allclose",
|
58 | 60 | "any",
|
| 61 | + "array_equal", |
| 62 | + "array_equiv", |
59 | 63 | "equal",
|
60 | 64 | "greater",
|
61 | 65 | "greater_equal",
|
@@ -112,7 +116,7 @@ def all(a, /, axis=None, out=None, keepdims=False, *, where=True):
|
112 | 116 | Returns
|
113 | 117 | -------
|
114 | 118 | out : dpnp.ndarray
|
115 |
| - An array with a data type of `bool` |
| 119 | + An array with a data type of `bool`. |
116 | 120 | containing the results of the logical AND reduction is returned
|
117 | 121 | unless `out` is specified. Otherwise, a reference to `out` is returned.
|
118 | 122 | The result has the same shape as `a` if `axis` is not ``None``
|
@@ -276,7 +280,7 @@ def any(a, /, axis=None, out=None, keepdims=False, *, where=True):
|
276 | 280 | Returns
|
277 | 281 | -------
|
278 | 282 | out : dpnp.ndarray
|
279 |
| - An array with a data type of `bool` |
| 283 | + An array with a data type of `bool`. |
280 | 284 | containing the results of the logical OR reduction is returned
|
281 | 285 | unless `out` is specified. Otherwise, a reference to `out` is returned.
|
282 | 286 | The result has the same shape as `a` if `axis` is not ``None``
|
@@ -337,6 +341,191 @@ def any(a, /, axis=None, out=None, keepdims=False, *, where=True):
|
337 | 341 | return dpnp.get_result_array(usm_res, out)
|
338 | 342 |
|
339 | 343 |
|
| 344 | +def array_equal(a1, a2, equal_nan=False): |
| 345 | + """ |
| 346 | + ``True`` if two arrays have the same shape and elements, ``False`` |
| 347 | + otherwise. |
| 348 | +
|
| 349 | + For full documentation refer to :obj:`numpy.array_equal`. |
| 350 | +
|
| 351 | + Parameters |
| 352 | + ---------- |
| 353 | + a1 : {dpnp.ndarray, usm_ndarray, scalar} |
| 354 | + First input array. |
| 355 | + Both inputs `x1` and `x2` can not be scalars at the same time. |
| 356 | + a2 : {dpnp.ndarray, usm_ndarray, scalar} |
| 357 | + Second input array. |
| 358 | + Both inputs `x1` and `x2` can not be scalars at the same time. |
| 359 | + equal_nan : bool, optional |
| 360 | + Whether to compare ``NaNs`` as equal. If the dtype of `a1` and `a2` is |
| 361 | + complex, values will be considered equal if either the real or the |
| 362 | + imaginary component of a given value is ``NaN``. |
| 363 | + Default: ``False``. |
| 364 | +
|
| 365 | + Returns |
| 366 | + ------- |
| 367 | + b : dpnp.ndarray |
| 368 | + An array with a data type of `bool`. |
| 369 | + Returns ``True`` if the arrays are equal. |
| 370 | +
|
| 371 | + See Also |
| 372 | + -------- |
| 373 | + :obj:`dpnp.allclose`: Returns ``True`` if two arrays are element-wise equal |
| 374 | + within a tolerance. |
| 375 | + :obj:`dpnp.array_equiv`: Returns ``True`` if input arrays are shape |
| 376 | + consistent and all elements equal. |
| 377 | +
|
| 378 | + Examples |
| 379 | + -------- |
| 380 | + >>> import dpnp as np |
| 381 | + >>> a = np.array([1, 2]) |
| 382 | + >>> b = np.array([1, 2]) |
| 383 | + >>> np.array_equal(a, b) |
| 384 | + array(True) |
| 385 | +
|
| 386 | + >>> b = np.array([1, 2, 3]) |
| 387 | + >>> np.array_equal(a, b) |
| 388 | + array(False) |
| 389 | +
|
| 390 | + >>> b = np.array([1, 4]) |
| 391 | + >>> np.array_equal(a, b) |
| 392 | + array(False) |
| 393 | +
|
| 394 | + >>> a = np.array([1, np.nan]) |
| 395 | + >>> np.array_equal(a, a) |
| 396 | + array(False) |
| 397 | +
|
| 398 | + >>> np.array_equal(a, a, equal_nan=True) |
| 399 | + array(True) |
| 400 | +
|
| 401 | + When ``equal_nan`` is ``True``, complex values with nan components are |
| 402 | + considered equal if either the real *or* the imaginary components are |
| 403 | + ``NaNs``. |
| 404 | +
|
| 405 | + >>> a = np.array([1 + 1j]) |
| 406 | + >>> b = a.copy() |
| 407 | + >>> a.real = np.nan |
| 408 | + >>> b.imag = np.nan |
| 409 | + >>> np.array_equal(a, b, equal_nan=True) |
| 410 | + array(True) |
| 411 | +
|
| 412 | + """ |
| 413 | + |
| 414 | + dpnp.check_supported_arrays_type(a1, a2, scalar_type=True) |
| 415 | + if dpnp.isscalar(a1): |
| 416 | + usm_type_alloc = a2.usm_type |
| 417 | + sycl_queue_alloc = a2.sycl_queue |
| 418 | + a1 = dpnp.array( |
| 419 | + a1, |
| 420 | + dtype=dpnp.result_type(a1, a2), |
| 421 | + usm_type=usm_type_alloc, |
| 422 | + sycl_queue=sycl_queue_alloc, |
| 423 | + ) |
| 424 | + elif dpnp.isscalar(a2): |
| 425 | + usm_type_alloc = a1.usm_type |
| 426 | + sycl_queue_alloc = a1.sycl_queue |
| 427 | + a2 = dpnp.array( |
| 428 | + a2, |
| 429 | + dtype=dpnp.result_type(a1, a2), |
| 430 | + usm_type=usm_type_alloc, |
| 431 | + sycl_queue=sycl_queue_alloc, |
| 432 | + ) |
| 433 | + else: |
| 434 | + usm_type_alloc, sycl_queue_alloc = get_usm_allocations([a1, a2]) |
| 435 | + |
| 436 | + if a1.shape != a2.shape: |
| 437 | + return dpnp.array( |
| 438 | + False, usm_type=usm_type_alloc, sycl_queue=sycl_queue_alloc |
| 439 | + ) |
| 440 | + |
| 441 | + if not equal_nan: |
| 442 | + return (a1 == a2).all() |
| 443 | + |
| 444 | + if a1 is a2: |
| 445 | + # NaN will compare equal so an array will compare equal to itself |
| 446 | + return dpnp.array( |
| 447 | + True, usm_type=usm_type_alloc, sycl_queue=sycl_queue_alloc |
| 448 | + ) |
| 449 | + |
| 450 | + if not ( |
| 451 | + dpnp.issubdtype(a1, dpnp.inexact) or dpnp.issubdtype(a2, dpnp.inexact) |
| 452 | + ): |
| 453 | + return (a1 == a2).all() |
| 454 | + |
| 455 | + # Handling NaN values if equal_nan is True |
| 456 | + a1nan, a2nan = isnan(a1), isnan(a2) |
| 457 | + # NaNs occur at different locations |
| 458 | + if not (a1nan == a2nan).all(): |
| 459 | + return dpnp.array( |
| 460 | + False, usm_type=usm_type_alloc, sycl_queue=sycl_queue_alloc |
| 461 | + ) |
| 462 | + # Shapes of a1, a2 and masks are guaranteed to be consistent by this point |
| 463 | + return (a1[~a1nan] == a2[~a1nan]).all() |
| 464 | + |
| 465 | + |
| 466 | +def array_equiv(a1, a2): |
| 467 | + """ |
| 468 | + Returns ``True`` if input arrays are shape consistent and all elements |
| 469 | + equal. |
| 470 | +
|
| 471 | + Shape consistent means they are either the same shape, or one input array |
| 472 | + can be broadcasted to create the same shape as the other one. |
| 473 | +
|
| 474 | + For full documentation refer to :obj:`numpy.array_equiv`. |
| 475 | +
|
| 476 | + Parameters |
| 477 | + ---------- |
| 478 | + a1 : {dpnp.ndarray, usm_ndarray, scalar} |
| 479 | + First input array. |
| 480 | + Both inputs `x1` and `x2` can not be scalars at the same time. |
| 481 | + a2 : {dpnp.ndarray, usm_ndarray, scalar} |
| 482 | + Second input array. |
| 483 | + Both inputs `x1` and `x2` can not be scalars at the same time. |
| 484 | +
|
| 485 | + Returns |
| 486 | + ------- |
| 487 | + out : dpnp.ndarray |
| 488 | + An array with a data type of `bool`. |
| 489 | + ``True`` if equivalent, ``False`` otherwise. |
| 490 | +
|
| 491 | + Examples |
| 492 | + -------- |
| 493 | + >>> import dpnp as np |
| 494 | + >>> a = np.array([1, 2]) |
| 495 | + >>> b = np.array([1, 2]) |
| 496 | + >>> c = np.array([1, 3]) |
| 497 | + >>> np.array_equiv(a, b) |
| 498 | + array(True) |
| 499 | + >>> np.array_equiv(a, c) |
| 500 | + array(False) |
| 501 | +
|
| 502 | + Showing the shape equivalence: |
| 503 | +
|
| 504 | + >>> b = np.array([[1, 2], [1, 2]]) |
| 505 | + >>> c = np.array([[1, 2, 1, 2], [1, 2, 1, 2]]) |
| 506 | + >>> np.array_equiv(a, b) |
| 507 | + array(True) |
| 508 | + >>> np.array_equiv(a, c) |
| 509 | + array(False) |
| 510 | +
|
| 511 | + >>> b = np.array([[1, 2], [1, 3]]) |
| 512 | + >>> np.array_equiv(a, b) |
| 513 | + array(False) |
| 514 | +
|
| 515 | + """ |
| 516 | + |
| 517 | + dpnp.check_supported_arrays_type(a1, a2, scalar_type=True) |
| 518 | + if not dpnp.isscalar(a1) and not dpnp.isscalar(a2): |
| 519 | + usm_type_alloc, sycl_queue_alloc = get_usm_allocations([a1, a2]) |
| 520 | + try: |
| 521 | + dpnp.broadcast_arrays(a1, a2) |
| 522 | + except ValueError: |
| 523 | + return dpnp.array( |
| 524 | + False, usm_type=usm_type_alloc, sycl_queue=sycl_queue_alloc |
| 525 | + ) |
| 526 | + return (a1 == a2).all() |
| 527 | + |
| 528 | + |
340 | 529 | _EQUAL_DOCSTRING = """
|
341 | 530 | Calculates equality test results for each element `x1_i` of the input array `x1`
|
342 | 531 | with the respective element `x2_i` of the input array `x2`.
|
|
0 commit comments