Skip to content

Added dtype compatibility check for device for result_type() and can_cast() functions. #1053

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions dpctl/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
broadcast_to,
can_cast,
concat,
device_result_type,
expand_dims,
finfo,
flip,
Expand Down Expand Up @@ -137,4 +138,5 @@
"get_print_options",
"set_print_options",
"print_options",
"device_result_type",
]
49 changes: 48 additions & 1 deletion dpctl/tensor/_manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ def stack(arrays, axis=0):
return res


def can_cast(from_, to, casting="safe"):
def can_cast(from_, to, casting="safe", device=None):
"""
can_cast(from: usm_ndarray or dtype, to: dtype) -> bool

Expand All @@ -454,6 +454,25 @@ def can_cast(from_, to, casting="safe"):

_supported_dtype([dtype_from, dtype_to])

if device is not None:
if isinstance(device, (dpctl.SyclQueue, dpt.Device)):
device = device.sycl_device
if not isinstance(device, dpctl.SyclDevice):
raise TypeError(f"Expected sycl_device type, got {type(device)}.")
if (
not device.has_aspect_fp16
and dtype_to == dpt.float16
or not device.has_aspect_fp64
and (dtype_to == dpt.float64 or dtype_to == dpt.complex128)
):
return False
if not device.has_aspect_fp64 and (
dtype_to == dpt.complex64
or dtype_to == dpt.float32
and dtype_from is not complex
):
return True

return np.can_cast(dtype_from, dtype_to, casting)


Expand All @@ -475,6 +494,34 @@ def result_type(*arrays_and_dtypes):
return np.result_type(*dtypes)


def device_result_type(device, *arrays_and_dtypes):
"""
device_result_type(device: sycl_device, arrays_and_dtypes: an arbitrary \
number usm_ndarrays or dtypes) -> dtype

Returns the dtype that results from applying the Type Promotion Rules to \
the arguments on current device.
"""
dt = result_type(*arrays_and_dtypes)

if device is not None:
if isinstance(device, (dpctl.SyclQueue, dpt.Device)):
device = device.sycl_device
if not isinstance(device, dpctl.SyclDevice):
raise TypeError(f"Expected sycl_device type, got {type(device)}.")
if (
dt == dpt.float16
and not device.has_aspect_fp16
or dt == dpt.float64
and not device.has_aspect_fp64
):
return dpt.float32
if dt == dpt.complex128 and not device.has_aspect_fp64:
return dpt.complex64

return dt


def iinfo(dtype):
"""
iinfo(dtype: integer data-type) -> iinfo_object
Expand Down