Skip to content

Commit 2134c84

Browse files
committed
Added dtype compatibility check for device for result_type() and can_cast() functions.
1 parent db8fe2f commit 2134c84

File tree

1 file changed

+46
-1
lines changed

1 file changed

+46
-1
lines changed

dpctl/tensor/_manipulation_functions.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -436,7 +436,7 @@ def stack(arrays, axis=0):
436436
return res
437437

438438

439-
def can_cast(from_, to, casting="safe"):
439+
def can_cast(from_, to, casting="safe", device=None):
440440
"""
441441
can_cast(from: usm_ndarray or dtype, to: dtype) -> bool
442442
@@ -454,6 +454,23 @@ def can_cast(from_, to, casting="safe"):
454454

455455
_supported_dtype([dtype_from, dtype_to])
456456

457+
if device is not None:
458+
if isinstance(device, (dpctl.SyclQueue, dpt._device.Device)):
459+
device = device.sycl_device
460+
if not isinstance(device, dpctl.SyclDevice):
461+
raise TypeError(f"Expected sycl_device type, got {type(device)}.")
462+
if (
463+
not device.has_aspect_fp16
464+
and to == dpt.float16
465+
or not device.has_aspect_fp64
466+
and (to == dpt.float64 or to == dpt.complex128)
467+
):
468+
return False
469+
if not device.has_aspect_fp64 and (
470+
to == dpt.complex64 or to == dpt.float32 and from_ is not complex
471+
):
472+
return True
473+
457474
return np.can_cast(dtype_from, dtype_to, casting)
458475

459476

@@ -475,6 +492,34 @@ def result_type(*arrays_and_dtypes):
475492
return np.result_type(*dtypes)
476493

477494

495+
def device_result_type(device, *arrays_and_dtypes):
496+
"""
497+
device_result_type(device: sycl_device, arrays_and_dtypes: an arbitrary \
498+
number usm_ndarrays or dtypes) -> dtype
499+
500+
Returns the dtype that results from applying the Type Promotion Rules to \
501+
the arguments on current device.
502+
"""
503+
dt = result_type(*arrays_and_dtypes)
504+
505+
if device is not None:
506+
if isinstance(device, (dpctl.SyclQueue, dpt._device.Device)):
507+
device = device.sycl_device
508+
if not isinstance(device, dpctl.SyclDevice):
509+
raise TypeError(f"Expected sycl_device type, got {type(device)}.")
510+
if (
511+
dt == dpt.float16
512+
and not device.has_aspect_fp16
513+
or dt == dpt.float64
514+
and not device.has_aspect_fp64
515+
):
516+
return dpt.float32
517+
if dt == dpt.complex128 and not device.has_aspect_fp64:
518+
return dpt.complex64
519+
520+
return dt
521+
522+
478523
def iinfo(dtype):
479524
"""
480525
iinfo(dtype: integer data-type) -> iinfo_object

0 commit comments

Comments
 (0)