@@ -168,7 +168,7 @@ cdef void _managed_tensor_versioned_deleter(DLManagedTensorVersioned *dlmv_tenso
168
168
stdlib.free(dlmv_tensor)
169
169
170
170
171
- cdef object _get_default_context(c_dpctl.SyclDevice dev) except * :
171
+ cdef object _get_default_context(c_dpctl.SyclDevice dev):
172
172
try :
173
173
default_context = dev.sycl_platform.default_context
174
174
except RuntimeError :
@@ -178,7 +178,7 @@ cdef object _get_default_context(c_dpctl.SyclDevice dev) except *:
178
178
return default_context
179
179
180
180
181
- cdef int get_parent_device_ordinal_id(c_dpctl.SyclDevice dev) except * :
181
+ cdef int get_parent_device_ordinal_id(c_dpctl.SyclDevice dev) except - 1 :
182
182
cdef DPCTLSyclDeviceRef pDRef = NULL
183
183
cdef DPCTLSyclDeviceRef tDRef = NULL
184
184
cdef c_dpctl.SyclDevice p_dev
@@ -201,7 +201,7 @@ cdef int get_parent_device_ordinal_id(c_dpctl.SyclDevice dev) except *:
201
201
202
202
cdef int get_array_dlpack_device_id(
203
203
usm_ndarray usm_ary
204
- ) except * :
204
+ ) except - 1 :
205
205
""" Finds ordinal number of the parent of device where array
206
206
was allocated.
207
207
"""
@@ -1011,17 +1011,22 @@ def from_dlpack(x, /, *, device=None, copy=None):
1011
1011
1012
1012
"""
1013
1013
dlpack_attr = getattr (x, " __dlpack__" , None )
1014
- if not callable (dlpack_attr):
1014
+ dlpack_dev_attr = getattr (x, " __dlpack_device__" , None )
1015
+ if not callable (dlpack_attr) or not callable (dlpack_dev_attr):
1015
1016
raise TypeError (
1016
1017
f" The argument of type {type(x)} does not implement "
1017
- " `__dlpack__` method ."
1018
+ " `__dlpack__` and `__dlpack_device__` methods ."
1018
1019
)
1019
1020
try :
1020
1021
# device is converted to a dlpack_device if necessary
1021
1022
dl_device = None
1022
1023
if device:
1023
1024
if isinstance (device, tuple ):
1024
1025
dl_device = device
1026
+ if len (dl_device) != 2 :
1027
+ raise ValueError (
1028
+ " Argument `device` specified as a tuple must have length 2"
1029
+ )
1025
1030
else :
1026
1031
if not isinstance (device, dpctl.SyclDevice):
1027
1032
d = Device.create_device(device).sycl_device
@@ -1031,8 +1036,34 @@ def from_dlpack(x, /, *, device=None, copy=None):
1031
1036
dlpack_capsule = dlpack_attr(max_version = get_build_dlpack_version(), dl_device = dl_device, copy = copy)
1032
1037
return from_dlpack_capsule(dlpack_capsule)
1033
1038
except TypeError :
1034
- if (dl_device is None ) or (dl_device == x.__dlpack_device__()):
1039
+ x_dldev = dlpack_dev_attr()
1040
+ if (dl_device is None ) or (dl_device == x_dldev):
1035
1041
dlpack_capsule = dlpack_attr()
1036
1042
return from_dlpack_capsule(dlpack_capsule)
1043
+ # must copy via host
1044
+ if copy is False :
1045
+ raise ValueError (
1046
+ " Importing data via DLPack requires copying, but copy=False was provided"
1047
+ )
1048
+ if dl_device[0 ] != device_OneAPI:
1049
+ raise ValueError (f" Can not import to requested device {dl_device}" )
1050
+ if x_dldev == (device_CPU, 0 ):
1051
+ host_blob = x
1052
+ else :
1053
+ dlpack_capsule = dlpack_attr(max_version = (1 , 0 ), dl_device = (device_CPU, 0 ), copy = copy)
1054
+ host_blob = from_dlpack_capsule(dlpack_capsule)
1055
+ device_id = dl_device[1 ]
1056
+ root_device = dpctl.SyclDevice(str (< int > device_id))
1057
+ q = Device.create_device(root_device).sycl_queue
1058
+ np_ary = np.asarray(host_blob)
1059
+ dt = np_ary.dtype
1060
+ if dt.char in " dD" and q.sycl_device.has_aspect_fp64 is False :
1061
+ Xusm_dtype = (
1062
+ " float32" if dt.char == " d" else " complex64"
1063
+ )
1037
1064
else :
1038
- raise ValueError
1065
+ Xusm_dtype = dt
1066
+ usm_mem = dpmem.MemoryUSMDevice(np_ary.nbytes, queue = q)
1067
+ usm_ary = usm_ndarray(np_ary.shape, dtype = Xusm_dtype, buffer = usm_mem)
1068
+ usm_mem.copy_from_host(np.reshape(np_ary.view(dtype = " u1" ), - 1 ))
1069
+ return usm_ary
0 commit comments