Skip to content

Commit 984dd15

Browse files
committed
Added dtype compatibility check for device for result_type() and can_cast() functions.
1 parent db8fe2f commit 984dd15

File tree

2 files changed

+50
-1
lines changed

2 files changed

+50
-1
lines changed

dpctl/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
broadcast_to,
4949
can_cast,
5050
concat,
51+
device_result_type,
5152
expand_dims,
5253
finfo,
5354
flip,
@@ -137,4 +138,5 @@
137138
"get_print_options",
138139
"set_print_options",
139140
"print_options",
141+
"device_result_type",
140142
]

dpctl/tensor/_manipulation_functions.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -436,7 +436,7 @@ def stack(arrays, axis=0):
436436
return res
437437

438438

439-
def can_cast(from_, to, casting="safe"):
439+
def can_cast(from_, to, casting="safe", device=None):
440440
"""
441441
can_cast(from: usm_ndarray or dtype, to: dtype) -> bool
442442
@@ -454,6 +454,25 @@ def can_cast(from_, to, casting="safe"):
454454

455455
_supported_dtype([dtype_from, dtype_to])
456456

457+
if device is not None:
458+
if isinstance(device, (dpctl.SyclQueue, dpt.Device)):
459+
device = device.sycl_device
460+
if not isinstance(device, dpctl.SyclDevice):
461+
raise TypeError(f"Expected sycl_device type, got {type(device)}.")
462+
if (
463+
not device.has_aspect_fp16
464+
and dtype_to == dpt.float16
465+
or not device.has_aspect_fp64
466+
and (dtype_to == dpt.float64 or dtype_to == dpt.complex128)
467+
):
468+
return False
469+
if not device.has_aspect_fp64 and (
470+
dtype_to == dpt.complex64
471+
or dtype_to == dpt.float32
472+
and dtype_from is not complex
473+
):
474+
return True
475+
457476
return np.can_cast(dtype_from, dtype_to, casting)
458477

459478

@@ -475,6 +494,34 @@ def result_type(*arrays_and_dtypes):
475494
return np.result_type(*dtypes)
476495

477496

497+
def device_result_type(device, *arrays_and_dtypes):
498+
"""
499+
device_result_type(device: sycl_device, arrays_and_dtypes: an arbitrary \
500+
number usm_ndarrays or dtypes) -> dtype
501+
502+
Returns the dtype that results from applying the Type Promotion Rules to \
503+
the arguments on current device.
504+
"""
505+
dt = result_type(*arrays_and_dtypes)
506+
507+
if device is not None:
508+
if isinstance(device, (dpctl.SyclQueue, dpt.Device)):
509+
device = device.sycl_device
510+
if not isinstance(device, dpctl.SyclDevice):
511+
raise TypeError(f"Expected sycl_device type, got {type(device)}.")
512+
if (
513+
dt == dpt.float16
514+
and not device.has_aspect_fp16
515+
or dt == dpt.float64
516+
and not device.has_aspect_fp64
517+
):
518+
return dpt.float32
519+
if dt == dpt.complex128 and not device.has_aspect_fp64:
520+
return dpt.complex64
521+
522+
return dt
523+
524+
478525
def iinfo(dtype):
479526
"""
480527
iinfo(dtype: integer data-type) -> iinfo_object

0 commit comments

Comments
 (0)