Skip to content

Commit e01f33b

Browse files
Changed from_dlpack to copy via host is needed
This enables dpt.from_dlpack(numpy_array, device="opencl:cpu")
1 parent 2946b40 commit e01f33b

File tree

1 file changed

+38
-7
lines changed

1 file changed

+38
-7
lines changed

dpctl/tensor/_dlpack.pyx

Lines changed: 38 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ cdef void _managed_tensor_versioned_deleter(DLManagedTensorVersioned *dlmv_tenso
168168
stdlib.free(dlmv_tensor)
169169

170170

171-
cdef object _get_default_context(c_dpctl.SyclDevice dev) except *:
171+
cdef object _get_default_context(c_dpctl.SyclDevice dev):
172172
try:
173173
default_context = dev.sycl_platform.default_context
174174
except RuntimeError:
@@ -178,7 +178,7 @@ cdef object _get_default_context(c_dpctl.SyclDevice dev) except *:
178178
return default_context
179179

180180

181-
cdef int get_parent_device_ordinal_id(c_dpctl.SyclDevice dev) except *:
181+
cdef int get_parent_device_ordinal_id(c_dpctl.SyclDevice dev) except -1:
182182
cdef DPCTLSyclDeviceRef pDRef = NULL
183183
cdef DPCTLSyclDeviceRef tDRef = NULL
184184
cdef c_dpctl.SyclDevice p_dev
@@ -201,7 +201,7 @@ cdef int get_parent_device_ordinal_id(c_dpctl.SyclDevice dev) except *:
201201

202202
cdef int get_array_dlpack_device_id(
203203
usm_ndarray usm_ary
204-
) except *:
204+
) except -1:
205205
"""Finds ordinal number of the parent of device where array
206206
was allocated.
207207
"""
@@ -1011,17 +1011,22 @@ def from_dlpack(x, /, *, device=None, copy=None):
10111011
10121012
"""
10131013
dlpack_attr = getattr(x, "__dlpack__", None)
1014-
if not callable(dlpack_attr):
1014+
dlpack_dev_attr = getattr(x, "__dlpack_device__", None)
1015+
if not callable(dlpack_attr) or not callable(dlpack_dev_attr):
10151016
raise TypeError(
10161017
f"The argument of type {type(x)} does not implement "
1017-
"`__dlpack__` method."
1018+
"`__dlpack__` and `__dlpack_device__` methods."
10181019
)
10191020
try:
10201021
# device is converted to a dlpack_device if necessary
10211022
dl_device = None
10221023
if device:
10231024
if isinstance(device, tuple):
10241025
dl_device = device
1026+
if len(dl_device) != 2:
1027+
raise ValueError(
1028+
"Argument `device` specified as a tuple must have length 2"
1029+
)
10251030
else:
10261031
if not isinstance(device, dpctl.SyclDevice):
10271032
d = Device.create_device(device).sycl_device
@@ -1031,8 +1036,34 @@ def from_dlpack(x, /, *, device=None, copy=None):
10311036
dlpack_capsule = dlpack_attr(max_version=get_build_dlpack_version(), dl_device=dl_device, copy=copy)
10321037
return from_dlpack_capsule(dlpack_capsule)
10331038
except TypeError:
1034-
if (dl_device is None) or (dl_device == x.__dlpack_device__()):
1039+
x_dldev = dlpack_dev_attr()
1040+
if (dl_device is None) or (dl_device == x_dldev):
10351041
dlpack_capsule = dlpack_attr()
10361042
return from_dlpack_capsule(dlpack_capsule)
1043+
# must copy via host
1044+
if copy is False:
1045+
raise ValueError(
1046+
"Importing data via DLPack requires copying, but copy=False was provided"
1047+
)
1048+
if dl_device[0] != device_OneAPI:
1049+
raise ValueError(f"Can not import to requested device {dl_device}")
1050+
if x_dldev == (device_CPU, 0):
1051+
host_blob = x
1052+
else:
1053+
dlpack_capsule = dlpack_attr(max_version=(1, 0), dl_device=(device_CPU, 0), copy=copy)
1054+
host_blob = from_dlpack_capsule(dlpack_capsule)
1055+
device_id = dl_device[1]
1056+
root_device = dpctl.SyclDevice(str(<int>device_id))
1057+
q = Device.create_device(root_device).sycl_queue
1058+
np_ary = np.asarray(host_blob)
1059+
dt = np_ary.dtype
1060+
if dt.char in "dD" and q.sycl_device.has_aspect_fp64 is False:
1061+
Xusm_dtype = (
1062+
"float32" if dt.char == "d" else "complex64"
1063+
)
10371064
else:
1038-
raise ValueError
1065+
Xusm_dtype = dt
1066+
usm_mem = dpmem.MemoryUSMDevice(np_ary.nbytes, queue=q)
1067+
usm_ary = usm_ndarray(np_ary.shape, dtype=Xusm_dtype, buffer=usm_mem)
1068+
usm_mem.copy_from_host(np.reshape(np_ary.view(dtype="u1"), -1))
1069+
return usm_ary

0 commit comments

Comments
 (0)