@@ -930,6 +930,22 @@ cpdef object from_dlpack_capsule(object py_caps):
930
930
" The DLPack tensor resides on unsupported device."
931
931
)
932
932
933
+ cdef usm_ndarray _to_usm_ary_from_host_blob(object host_blob, int device_id):
934
+ root_device = dpctl.SyclDevice(str (< int > device_id))
935
+ q = Device.create_device(root_device).sycl_queue
936
+ np_ary = np.asarray(host_blob)
937
+ dt = np_ary.dtype
938
+ if dt.char in " dD" and q.sycl_device.has_aspect_fp64 is False :
939
+ Xusm_dtype = (
940
+ " float32" if dt.char == " d" else " complex64"
941
+ )
942
+ else :
943
+ Xusm_dtype = dt
944
+ usm_mem = dpmem.MemoryUSMDevice(np_ary.nbytes, queue = q)
945
+ usm_ary = usm_ndarray(np_ary.shape, dtype = Xusm_dtype, buffer = usm_mem)
946
+ usm_mem.copy_from_host(np.reshape(np_ary.view(dtype = " u1" ), - 1 ))
947
+ return usm_ary
948
+
933
949
934
950
def from_dlpack (x , /, *, device = None , copy = None ):
935
951
""" from_dlpack(x, /, *, device=None, copy=None)
@@ -1031,34 +1047,36 @@ def from_dlpack(x, /, *, device=None, copy=None):
1031
1047
dlpack_capsule = dlpack_attr(max_version = get_build_dlpack_version(), dl_device = dl_device, copy = copy)
1032
1048
return from_dlpack_capsule(dlpack_capsule)
1033
1049
except TypeError :
1050
+ # max_version/dl_device, copy keywords are not supported by __dlpack__
1034
1051
x_dldev = dlpack_dev_attr()
1035
1052
if (dl_device is None ) or (dl_device == x_dldev):
1036
1053
dlpack_capsule = dlpack_attr()
1037
1054
return from_dlpack_capsule(dlpack_capsule)
1038
1055
# must copy via host
1039
1056
if copy is False :
1040
- raise ValueError (
1057
+ raise BufferError (
1041
1058
" Importing data via DLPack requires copying, but copy=False was provided"
1042
1059
)
1060
+ if x_dldev == (device_CPU, 0 ) and dl_device[0 ] == device_OneAPI:
1061
+ host_blob = x
1062
+ else :
1063
+ raise BufferError(f" Can not import to requested device {dl_device}" )
1064
+ return _to_usm_ary_from_host_blob(host_blob, dl_device[1 ])
1065
+ except BufferError as e:
1066
+ # we are here, because dlpack_attr could not deal with requested dl_device,
1067
+ # or copying was required
1068
+ if copy is False :
1069
+ raise BufferError(
1070
+ " Importing data via DLPack requires copying, but copy=False was provided"
1071
+ ) from e
1072
+ # must copy via host
1043
1073
if dl_device[0 ] != device_OneAPI:
1044
- raise ValueError (f" Can not import to requested device {dl_device}" )
1074
+ raise BufferError(f" Can not import to requested device {dl_device}" )
1075
+ x_dldev = dlpack_dev_attr()
1045
1076
if x_dldev == (device_CPU, 0 ):
1046
1077
host_blob = x
1047
1078
else :
1079
+ # this would fail anyway
1048
1080
dlpack_capsule = dlpack_attr(max_version = (1 , 0 ), dl_device = (device_CPU, 0 ), copy = copy)
1049
1081
host_blob = from_dlpack_capsule(dlpack_capsule)
1050
- device_id = dl_device[1 ]
1051
- root_device = dpctl.SyclDevice(str (< int > device_id))
1052
- q = Device.create_device(root_device).sycl_queue
1053
- np_ary = np.asarray(host_blob)
1054
- dt = np_ary.dtype
1055
- if dt.char in " dD" and q.sycl_device.has_aspect_fp64 is False :
1056
- Xusm_dtype = (
1057
- " float32" if dt.char == " d" else " complex64"
1058
- )
1059
- else :
1060
- Xusm_dtype = dt
1061
- usm_mem = dpmem.MemoryUSMDevice(np_ary.nbytes, queue = q)
1062
- usm_ary = usm_ndarray(np_ary.shape, dtype = Xusm_dtype, buffer = usm_mem)
1063
- usm_mem.copy_from_host(np.reshape(np_ary.view(dtype = " u1" ), - 1 ))
1064
- return usm_ary
1082
+ return _to_usm_ary_from_host_blob(host_blob, dl_device[1 ])
0 commit comments