Skip to content

Commit 55a3da8

Browse files
committed
Added CFD
1 parent 4347c09 commit 55a3da8

File tree

3 files changed

+58
-4
lines changed

3 files changed

+58
-4
lines changed

dpnp/dpnp_iface_logic.py

Lines changed: 54 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
import numpy
5151

5252
import dpnp
53+
import dpnp.dpnp_utils as utils
5354
from dpnp.dpnp_algo.dpnp_elementwise_common import DPNPBinaryFunc, DPNPUnaryFunc
5455
from dpnp.dpnp_array import dpnp_array
5556

@@ -412,15 +413,39 @@ def array_equal(a1, a2, equal_nan=False):
412413
413414
"""
414415
dpnp.check_supported_arrays_type(a1, a2, scalar_type=True)
416+
if dpnp.isscalar(a1):
417+
usm_type_alloc = a2.usm_type
418+
sycl_queue_alloc = a2.sycl_queue
419+
a1 = dpnp.array(
420+
a1,
421+
dtype=dpnp.result_type(a1, a2),
422+
usm_type=usm_type_alloc,
423+
sycl_queue=sycl_queue_alloc,
424+
)
425+
elif dpnp.isscalar(a2):
426+
usm_type_alloc = a1.usm_type
427+
sycl_queue_alloc = a1.sycl_queue
428+
a2 = dpnp.array(
429+
a2,
430+
dtype=dpnp.result_type(a1, a2),
431+
usm_type=usm_type_alloc,
432+
sycl_queue=sycl_queue_alloc,
433+
)
434+
else:
435+
usm_type_alloc, sycl_queue_alloc = utils.get_usm_allocations([a1, a2])
415436

416437
if a1.shape != a2.shape:
417-
return dpnp.array(False)
438+
return dpnp.array(
439+
False, usm_type=usm_type_alloc, sycl_queue=sycl_queue_alloc
440+
)
418441

419442
if not equal_nan:
420443
return (a1 == a2).all()
421444

422445
if a1 is a2:
423-
return dpnp.array(True)
446+
return dpnp.array(
447+
True, usm_type=usm_type_alloc, sycl_queue=sycl_queue_alloc
448+
)
424449

425450
cannot_have_nan = (
426451
dpnp.issubdtype(a1, dpnp.bool) or dpnp.issubdtype(a1, dpnp.integer)
@@ -433,7 +458,9 @@ def array_equal(a1, a2, equal_nan=False):
433458
a1nan, a2nan = isnan(a1), isnan(a2)
434459
# NaNs occur at different locations
435460
if not (a1nan == a2nan).all():
436-
return dpnp.array(False)
461+
return dpnp.array(
462+
False, usm_type=usm_type_alloc, sycl_queue=sycl_queue_alloc
463+
)
437464
# Shapes of a1, a2 and masks are guaranteed to be consistent by this point
438465
return (a1[~a1nan] == a2[~a1nan]).all()
439466

@@ -489,10 +516,33 @@ def array_equiv(a1, a2):
489516
490517
"""
491518
dpnp.check_supported_arrays_type(a1, a2, scalar_type=True)
519+
if dpnp.isscalar(a1):
520+
usm_type_alloc = a2.usm_type
521+
sycl_queue_alloc = a2.sycl_queue
522+
a1 = dpnp.array(
523+
a1,
524+
dtype=dpnp.result_type(a1, a2),
525+
usm_type=usm_type_alloc,
526+
sycl_queue=sycl_queue_alloc,
527+
)
528+
elif dpnp.isscalar(a2):
529+
usm_type_alloc = a1.usm_type
530+
sycl_queue_alloc = a1.sycl_queue
531+
a2 = dpnp.array(
532+
a2,
533+
dtype=dpnp.result_type(a1, a2),
534+
usm_type=usm_type_alloc,
535+
sycl_queue=sycl_queue_alloc,
536+
)
537+
else:
538+
usm_type_alloc, sycl_queue_alloc = utils.get_usm_allocations([a1, a2])
539+
492540
try:
493541
dpnp.broadcast_arrays(a1, a2)
494542
except ValueError:
495-
return dpnp.array(False)
543+
return dpnp.array(
544+
False, usm_type=usm_type_alloc, sycl_queue=sycl_queue_alloc
545+
)
496546
return (a1 == a2).all()
497547

498548

tests/test_sycl_queue.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -746,6 +746,8 @@ def test_2in_1out(func, data1, data2, device):
746746
@pytest.mark.parametrize(
747747
"op",
748748
[
749+
"array_equal",
750+
"array_equiv",
749751
"equal",
750752
"greater",
751753
"greater_equal",

tests/test_usm_type.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,8 @@ def test_coerced_usm_types_logic_op_1in(op, usm_type_x):
378378
@pytest.mark.parametrize(
379379
"op",
380380
[
381+
"array_equal",
382+
"array_equiv",
381383
"equal",
382384
"greater",
383385
"greater_equal",

0 commit comments

Comments
 (0)