@@ -436,7 +436,7 @@ def stack(arrays, axis=0):
436
436
return res
437
437
438
438
439
- def can_cast (from_ , to , casting = "safe" ):
439
+ def can_cast (from_ , to , casting = "safe" , device = None ):
440
440
"""
441
441
can_cast(from: usm_ndarray or dtype, to: dtype) -> bool
442
442
@@ -454,6 +454,23 @@ def can_cast(from_, to, casting="safe"):
454
454
455
455
_supported_dtype ([dtype_from , dtype_to ])
456
456
457
+ if device is not None :
458
+ if isinstance (device , (dpctl .SyclQueue , dpt ._device .Device )):
459
+ device = device .sycl_device
460
+ if not isinstance (device , dpctl .SyclDevice ):
461
+ raise TypeError (f"Expected sycl_device type, got { type (device )} ." )
462
+ if (
463
+ not device .has_aspect_fp16
464
+ and to == dpt .float16
465
+ or not device .has_aspect_fp64
466
+ and (to == dpt .float64 or to == dpt .complex128 )
467
+ ):
468
+ return False
469
+ if not device .has_aspect_fp64 and (
470
+ to == dpt .complex64 or to == dpt .float32 and from_ is not complex
471
+ ):
472
+ return True
473
+
457
474
return np .can_cast (dtype_from , dtype_to , casting )
458
475
459
476
@@ -475,6 +492,34 @@ def result_type(*arrays_and_dtypes):
475
492
return np .result_type (* dtypes )
476
493
477
494
495
+ def device_result_type (device , * arrays_and_dtypes ):
496
+ """
497
+ device_result_type(device: sycl_device, arrays_and_dtypes: an arbitrary \
498
+ number usm_ndarrays or dtypes) -> dtype
499
+
500
+ Returns the dtype that results from applying the Type Promotion Rules to \
501
+ the arguments on current device.
502
+ """
503
+ dt = result_type (* arrays_and_dtypes )
504
+
505
+ if device is not None :
506
+ if isinstance (device , (dpctl .SyclQueue , dpt ._device .Device )):
507
+ device = device .sycl_device
508
+ if not isinstance (device , dpctl .SyclDevice ):
509
+ raise TypeError (f"Expected sycl_device type, got { type (device )} ." )
510
+ if (
511
+ dt == dpt .float16
512
+ and not device .has_aspect_fp16
513
+ or dt == dpt .float64
514
+ and not device .has_aspect_fp64
515
+ ):
516
+ return dpt .float32
517
+ if dt == dpt .complex128 and not device .has_aspect_fp64 :
518
+ return dpt .complex64
519
+
520
+ return dt
521
+
522
+
478
523
def iinfo (dtype ):
479
524
"""
480
525
iinfo(dtype: integer data-type) -> iinfo_object
0 commit comments