@@ -436,7 +436,7 @@ def stack(arrays, axis=0):
436
436
return res
437
437
438
438
439
- def can_cast (from_ , to , casting = "safe" ):
439
+ def can_cast (from_ , to , casting = "safe" , device = None ):
440
440
"""
441
441
can_cast(from: usm_ndarray or dtype, to: dtype) -> bool
442
442
@@ -454,6 +454,25 @@ def can_cast(from_, to, casting="safe"):
454
454
455
455
_supported_dtype ([dtype_from , dtype_to ])
456
456
457
+ if device is not None :
458
+ if isinstance (device , (dpctl .SyclQueue , dpt .Device )):
459
+ device = device .sycl_device
460
+ if not isinstance (device , dpctl .SyclDevice ):
461
+ raise TypeError (f"Expected sycl_device type, got { type (device )} ." )
462
+ if (
463
+ not device .has_aspect_fp16
464
+ and dtype_to == dpt .float16
465
+ or not device .has_aspect_fp64
466
+ and (dtype_to == dpt .float64 or dtype_to == dpt .complex128 )
467
+ ):
468
+ return False
469
+ if not device .has_aspect_fp64 and (
470
+ dtype_to == dpt .complex64
471
+ or dtype_to == dpt .float32
472
+ and dtype_from is not complex
473
+ ):
474
+ return True
475
+
457
476
return np .can_cast (dtype_from , dtype_to , casting )
458
477
459
478
@@ -475,6 +494,34 @@ def result_type(*arrays_and_dtypes):
475
494
return np .result_type (* dtypes )
476
495
477
496
497
+ def device_result_type (device , * arrays_and_dtypes ):
498
+ """
499
+ device_result_type(device: sycl_device, arrays_and_dtypes: an arbitrary \
500
+ number usm_ndarrays or dtypes) -> dtype
501
+
502
+ Returns the dtype that results from applying the Type Promotion Rules to \
503
+ the arguments on current device.
504
+ """
505
+ dt = result_type (* arrays_and_dtypes )
506
+
507
+ if device is not None :
508
+ if isinstance (device , (dpctl .SyclQueue , dpt .Device )):
509
+ device = device .sycl_device
510
+ if not isinstance (device , dpctl .SyclDevice ):
511
+ raise TypeError (f"Expected sycl_device type, got { type (device )} ." )
512
+ if (
513
+ dt == dpt .float16
514
+ and not device .has_aspect_fp16
515
+ or dt == dpt .float64
516
+ and not device .has_aspect_fp64
517
+ ):
518
+ return dpt .float32
519
+ if dt == dpt .complex128 and not device .has_aspect_fp64 :
520
+ return dpt .complex64
521
+
522
+ return dt
523
+
524
+
478
525
def iinfo (dtype ):
479
526
"""
480
527
iinfo(dtype: integer data-type) -> iinfo_object
0 commit comments