Skip to content

Commit 68159fb

Browse files
committed
Added dtype compatibility check for device for result_type() and can_cast() functions.
1 parent 2134c84 commit 68159fb

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

dpctl/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
broadcast_to,
4949
can_cast,
5050
concat,
51+
device_result_type,
5152
expand_dims,
5253
finfo,
5354
flip,
@@ -137,4 +138,5 @@
137138
"get_print_options",
138139
"set_print_options",
139140
"print_options",
141+
"device_result_type",
140142
]

dpctl/tensor/_manipulation_functions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -455,7 +455,7 @@ def can_cast(from_, to, casting="safe", device=None):
455455
_supported_dtype([dtype_from, dtype_to])
456456

457457
if device is not None:
458-
if isinstance(device, (dpctl.SyclQueue, dpt._device.Device)):
458+
if isinstance(device, (dpctl.SyclQueue, dpt.Device)):
459459
device = device.sycl_device
460460
if not isinstance(device, dpctl.SyclDevice):
461461
raise TypeError(f"Expected sycl_device type, got {type(device)}.")
@@ -503,7 +503,7 @@ def device_result_type(device, *arrays_and_dtypes):
503503
dt = result_type(*arrays_and_dtypes)
504504

505505
if device is not None:
506-
if isinstance(device, (dpctl.SyclQueue, dpt._device.Device)):
506+
if isinstance(device, (dpctl.SyclQueue, dpt.Device)):
507507
device = device.sycl_device
508508
if not isinstance(device, dpctl.SyclDevice):
509509
raise TypeError(f"Expected sycl_device type, got {type(device)}.")

0 commit comments

Comments
 (0)