Skip to content

Commit 11195b8

Browse files
Handle possibilities for TypeError and BufferError
These may be hard to test
1 parent 35cb068 commit 11195b8

File tree

1 file changed

+35
-17
lines changed

1 file changed

+35
-17
lines changed

dpctl/tensor/_dlpack.pyx

Lines changed: 35 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -930,6 +930,22 @@ cpdef object from_dlpack_capsule(object py_caps):
930930
"The DLPack tensor resides on unsupported device."
931931
)
932932

933+
cdef usm_ndarray _to_usm_ary_from_host_blob(object host_blob, int device_id):
934+
root_device = dpctl.SyclDevice(str(<int>device_id))
935+
q = Device.create_device(root_device).sycl_queue
936+
np_ary = np.asarray(host_blob)
937+
dt = np_ary.dtype
938+
if dt.char in "dD" and q.sycl_device.has_aspect_fp64 is False:
939+
Xusm_dtype = (
940+
"float32" if dt.char == "d" else "complex64"
941+
)
942+
else:
943+
Xusm_dtype = dt
944+
usm_mem = dpmem.MemoryUSMDevice(np_ary.nbytes, queue=q)
945+
usm_ary = usm_ndarray(np_ary.shape, dtype=Xusm_dtype, buffer=usm_mem)
946+
usm_mem.copy_from_host(np.reshape(np_ary.view(dtype="u1"), -1))
947+
return usm_ary
948+
933949

934950
def from_dlpack(x, /, *, device=None, copy=None):
935951
""" from_dlpack(x, /, *, device=None, copy=None)
@@ -1031,34 +1047,36 @@ def from_dlpack(x, /, *, device=None, copy=None):
10311047
dlpack_capsule = dlpack_attr(max_version=get_build_dlpack_version(), dl_device=dl_device, copy=copy)
10321048
return from_dlpack_capsule(dlpack_capsule)
10331049
except TypeError:
1050+
# max_version/dl_device, copy keywords are not supported by __dlpack__
10341051
x_dldev = dlpack_dev_attr()
10351052
if (dl_device is None) or (dl_device == x_dldev):
10361053
dlpack_capsule = dlpack_attr()
10371054
return from_dlpack_capsule(dlpack_capsule)
10381055
# must copy via host
10391056
if copy is False:
1040-
raise ValueError(
1057+
raise BufferError(
10411058
"Importing data via DLPack requires copying, but copy=False was provided"
10421059
)
1060+
if x_dldev == (device_CPU, 0) and dl_device[0] == device_OneAPI:
1061+
host_blob = x
1062+
else:
1063+
raise BufferError(f"Can not import to requested device {dl_device}")
1064+
return _to_usm_ary_from_host_blob(host_blob, dl_device[1])
1065+
except BufferError as e:
1066+
# we are here, because dlpack_attr could not deal with requested dl_device,
1067+
# or copying was required
1068+
if copy is False:
1069+
raise BufferError(
1070+
"Importing data via DLPack requires copying, but copy=False was provided"
1071+
) from e
1072+
# must copy via host
10431073
if dl_device[0] != device_OneAPI:
1044-
raise ValueError(f"Can not import to requested device {dl_device}")
1074+
raise BufferError(f"Can not import to requested device {dl_device}")
1075+
x_dldev = dlpack_dev_attr()
10451076
if x_dldev == (device_CPU, 0):
10461077
host_blob = x
10471078
else:
1079+
# this would fail anyway
10481080
dlpack_capsule = dlpack_attr(max_version=(1, 0), dl_device=(device_CPU, 0), copy=copy)
10491081
host_blob = from_dlpack_capsule(dlpack_capsule)
1050-
device_id = dl_device[1]
1051-
root_device = dpctl.SyclDevice(str(<int>device_id))
1052-
q = Device.create_device(root_device).sycl_queue
1053-
np_ary = np.asarray(host_blob)
1054-
dt = np_ary.dtype
1055-
if dt.char in "dD" and q.sycl_device.has_aspect_fp64 is False:
1056-
Xusm_dtype = (
1057-
"float32" if dt.char == "d" else "complex64"
1058-
)
1059-
else:
1060-
Xusm_dtype = dt
1061-
usm_mem = dpmem.MemoryUSMDevice(np_ary.nbytes, queue=q)
1062-
usm_ary = usm_ndarray(np_ary.shape, dtype=Xusm_dtype, buffer=usm_mem)
1063-
usm_mem.copy_from_host(np.reshape(np_ary.view(dtype="u1"), -1))
1064-
return usm_ary
1082+
return _to_usm_ary_from_host_blob(host_blob, dl_device[1])

0 commit comments

Comments
 (0)