50
50
import numpy
51
51
52
52
import dpnp
53
+ import dpnp .dpnp_utils as utils
53
54
from dpnp .dpnp_algo .dpnp_elementwise_common import DPNPBinaryFunc , DPNPUnaryFunc
54
55
from dpnp .dpnp_array import dpnp_array
55
56
@@ -412,15 +413,39 @@ def array_equal(a1, a2, equal_nan=False):
412
413
413
414
"""
414
415
dpnp .check_supported_arrays_type (a1 , a2 , scalar_type = True )
416
+ if dpnp .isscalar (a1 ):
417
+ usm_type_alloc = a2 .usm_type
418
+ sycl_queue_alloc = a2 .sycl_queue
419
+ a1 = dpnp .array (
420
+ a1 ,
421
+ dtype = dpnp .result_type (a1 , a2 ),
422
+ usm_type = usm_type_alloc ,
423
+ sycl_queue = sycl_queue_alloc ,
424
+ )
425
+ elif dpnp .isscalar (a2 ):
426
+ usm_type_alloc = a1 .usm_type
427
+ sycl_queue_alloc = a1 .sycl_queue
428
+ a2 = dpnp .array (
429
+ a2 ,
430
+ dtype = dpnp .result_type (a1 , a2 ),
431
+ usm_type = usm_type_alloc ,
432
+ sycl_queue = sycl_queue_alloc ,
433
+ )
434
+ else :
435
+ usm_type_alloc , sycl_queue_alloc = utils .get_usm_allocations ([a1 , a2 ])
415
436
416
437
if a1 .shape != a2 .shape :
417
- return dpnp .array (False )
438
+ return dpnp .array (
439
+ False , usm_type = usm_type_alloc , sycl_queue = sycl_queue_alloc
440
+ )
418
441
419
442
if not equal_nan :
420
443
return (a1 == a2 ).all ()
421
444
422
445
if a1 is a2 :
423
- return dpnp .array (True )
446
+ return dpnp .array (
447
+ True , usm_type = usm_type_alloc , sycl_queue = sycl_queue_alloc
448
+ )
424
449
425
450
cannot_have_nan = (
426
451
dpnp .issubdtype (a1 , dpnp .bool ) or dpnp .issubdtype (a1 , dpnp .integer )
@@ -433,7 +458,9 @@ def array_equal(a1, a2, equal_nan=False):
433
458
a1nan , a2nan = isnan (a1 ), isnan (a2 )
434
459
# NaNs occur at different locations
435
460
if not (a1nan == a2nan ).all ():
436
- return dpnp .array (False )
461
+ return dpnp .array (
462
+ False , usm_type = usm_type_alloc , sycl_queue = sycl_queue_alloc
463
+ )
437
464
# Shapes of a1, a2 and masks are guaranteed to be consistent by this point
438
465
return (a1 [~ a1nan ] == a2 [~ a1nan ]).all ()
439
466
@@ -489,10 +516,33 @@ def array_equiv(a1, a2):
489
516
490
517
"""
491
518
dpnp .check_supported_arrays_type (a1 , a2 , scalar_type = True )
519
+ if dpnp .isscalar (a1 ):
520
+ usm_type_alloc = a2 .usm_type
521
+ sycl_queue_alloc = a2 .sycl_queue
522
+ a1 = dpnp .array (
523
+ a1 ,
524
+ dtype = dpnp .result_type (a1 , a2 ),
525
+ usm_type = usm_type_alloc ,
526
+ sycl_queue = sycl_queue_alloc ,
527
+ )
528
+ elif dpnp .isscalar (a2 ):
529
+ usm_type_alloc = a1 .usm_type
530
+ sycl_queue_alloc = a1 .sycl_queue
531
+ a2 = dpnp .array (
532
+ a2 ,
533
+ dtype = dpnp .result_type (a1 , a2 ),
534
+ usm_type = usm_type_alloc ,
535
+ sycl_queue = sycl_queue_alloc ,
536
+ )
537
+ else :
538
+ usm_type_alloc , sycl_queue_alloc = utils .get_usm_allocations ([a1 , a2 ])
539
+
492
540
try :
493
541
dpnp .broadcast_arrays (a1 , a2 )
494
542
except ValueError :
495
- return dpnp .array (False )
543
+ return dpnp .array (
544
+ False , usm_type = usm_type_alloc , sycl_queue = sycl_queue_alloc
545
+ )
496
546
return (a1 == a2 ).all ()
497
547
498
548
0 commit comments