Skip to content

Commit 39931b2

Browse files
Handle possibilities for TypeError and BufferError
These may be hard to test
1 parent 22579f6 commit 39931b2

File tree

1 file changed

+35
-17
lines changed

1 file changed

+35
-17
lines changed

dpctl/tensor/_dlpack.pyx

Lines changed: 35 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -935,6 +935,22 @@ cpdef object from_dlpack_capsule(object py_caps):
935935
"The DLPack tensor resides on unsupported device."
936936
)
937937

938+
cdef usm_ndarray _to_usm_ary_from_host_blob(object host_blob, int device_id):
939+
root_device = dpctl.SyclDevice(str(<int>device_id))
940+
q = Device.create_device(root_device).sycl_queue
941+
np_ary = np.asarray(host_blob)
942+
dt = np_ary.dtype
943+
if dt.char in "dD" and q.sycl_device.has_aspect_fp64 is False:
944+
Xusm_dtype = (
945+
"float32" if dt.char == "d" else "complex64"
946+
)
947+
else:
948+
Xusm_dtype = dt
949+
usm_mem = dpmem.MemoryUSMDevice(np_ary.nbytes, queue=q)
950+
usm_ary = usm_ndarray(np_ary.shape, dtype=Xusm_dtype, buffer=usm_mem)
951+
usm_mem.copy_from_host(np.reshape(np_ary.view(dtype="u1"), -1))
952+
return usm_ary
953+
938954

939955
def from_dlpack(x, /, *, device=None, copy=None):
940956
""" from_dlpack(x, /, *, device=None, copy=None)
@@ -1036,34 +1052,36 @@ def from_dlpack(x, /, *, device=None, copy=None):
10361052
dlpack_capsule = dlpack_attr(max_version=get_build_dlpack_version(), dl_device=dl_device, copy=copy)
10371053
return from_dlpack_capsule(dlpack_capsule)
10381054
except TypeError:
1055+
# max_version/dl_device, copy keywords are not supported by __dlpack__
10391056
x_dldev = dlpack_dev_attr()
10401057
if (dl_device is None) or (dl_device == x_dldev):
10411058
dlpack_capsule = dlpack_attr()
10421059
return from_dlpack_capsule(dlpack_capsule)
10431060
# must copy via host
10441061
if copy is False:
1045-
raise ValueError(
1062+
raise BufferError(
10461063
"Importing data via DLPack requires copying, but copy=False was provided"
10471064
)
1065+
if x_dldev == (device_CPU, 0) and dl_device[0] == device_OneAPI:
1066+
host_blob = x
1067+
else:
1068+
raise BufferError(f"Can not import to requested device {dl_device}")
1069+
return _to_usm_ary_from_host_blob(host_blob, dl_device[1])
1070+
except BufferError as e:
1071+
# we are here, because dlpack_attr could not deal with requested dl_device,
1072+
# or copying was required
1073+
if copy is False:
1074+
raise BufferError(
1075+
"Importing data via DLPack requires copying, but copy=False was provided"
1076+
) from e
1077+
# must copy via host
10481078
if dl_device[0] != device_OneAPI:
1049-
raise ValueError(f"Can not import to requested device {dl_device}")
1079+
raise BufferError(f"Can not import to requested device {dl_device}")
1080+
x_dldev = dlpack_dev_attr()
10501081
if x_dldev == (device_CPU, 0):
10511082
host_blob = x
10521083
else:
1084+
# this would fail anyway
10531085
dlpack_capsule = dlpack_attr(max_version=(1, 0), dl_device=(device_CPU, 0), copy=copy)
10541086
host_blob = from_dlpack_capsule(dlpack_capsule)
1055-
device_id = dl_device[1]
1056-
root_device = dpctl.SyclDevice(str(<int>device_id))
1057-
q = Device.create_device(root_device).sycl_queue
1058-
np_ary = np.asarray(host_blob)
1059-
dt = np_ary.dtype
1060-
if dt.char in "dD" and q.sycl_device.has_aspect_fp64 is False:
1061-
Xusm_dtype = (
1062-
"float32" if dt.char == "d" else "complex64"
1063-
)
1064-
else:
1065-
Xusm_dtype = dt
1066-
usm_mem = dpmem.MemoryUSMDevice(np_ary.nbytes, queue=q)
1067-
usm_ary = usm_ndarray(np_ary.shape, dtype=Xusm_dtype, buffer=usm_mem)
1068-
usm_mem.copy_from_host(np.reshape(np_ary.view(dtype="u1"), -1))
1069-
return usm_ary
1087+
return _to_usm_ary_from_host_blob(host_blob, dl_device[1])

0 commit comments

Comments
 (0)