@@ -935,6 +935,22 @@ cpdef object from_dlpack_capsule(object py_caps):
935
935
" The DLPack tensor resides on unsupported device."
936
936
)
937
937
938
+ cdef usm_ndarray _to_usm_ary_from_host_blob(object host_blob, int device_id):
939
+ root_device = dpctl.SyclDevice(str (< int > device_id))
940
+ q = Device.create_device(root_device).sycl_queue
941
+ np_ary = np.asarray(host_blob)
942
+ dt = np_ary.dtype
943
+ if dt.char in " dD" and q.sycl_device.has_aspect_fp64 is False :
944
+ Xusm_dtype = (
945
+ " float32" if dt.char == " d" else " complex64"
946
+ )
947
+ else :
948
+ Xusm_dtype = dt
949
+ usm_mem = dpmem.MemoryUSMDevice(np_ary.nbytes, queue = q)
950
+ usm_ary = usm_ndarray(np_ary.shape, dtype = Xusm_dtype, buffer = usm_mem)
951
+ usm_mem.copy_from_host(np.reshape(np_ary.view(dtype = " u1" ), - 1 ))
952
+ return usm_ary
953
+
938
954
939
955
def from_dlpack (x , /, *, device = None , copy = None ):
940
956
""" from_dlpack(x, /, *, device=None, copy=None)
@@ -1036,34 +1052,36 @@ def from_dlpack(x, /, *, device=None, copy=None):
1036
1052
dlpack_capsule = dlpack_attr(max_version = get_build_dlpack_version(), dl_device = dl_device, copy = copy)
1037
1053
return from_dlpack_capsule(dlpack_capsule)
1038
1054
except TypeError :
1055
+ # max_version/dl_device, copy keywords are not supported by __dlpack__
1039
1056
x_dldev = dlpack_dev_attr()
1040
1057
if (dl_device is None ) or (dl_device == x_dldev):
1041
1058
dlpack_capsule = dlpack_attr()
1042
1059
return from_dlpack_capsule(dlpack_capsule)
1043
1060
# must copy via host
1044
1061
if copy is False :
1045
- raise ValueError (
1062
+ raise BufferError (
1046
1063
" Importing data via DLPack requires copying, but copy=False was provided"
1047
1064
)
1065
+ if x_dldev == (device_CPU, 0 ) and dl_device[0 ] == device_OneAPI:
1066
+ host_blob = x
1067
+ else :
1068
+ raise BufferError(f" Can not import to requested device {dl_device}" )
1069
+ return _to_usm_ary_from_host_blob(host_blob, dl_device[1 ])
1070
+ except BufferError as e:
1071
+ # we are here, because dlpack_attr could not deal with requested dl_device,
1072
+ # or copying was required
1073
+ if copy is False :
1074
+ raise BufferError(
1075
+ " Importing data via DLPack requires copying, but copy=False was provided"
1076
+ ) from e
1077
+ # must copy via host
1048
1078
if dl_device[0 ] != device_OneAPI:
1049
- raise ValueError (f" Can not import to requested device {dl_device}" )
1079
+ raise BufferError(f" Can not import to requested device {dl_device}" )
1080
+ x_dldev = dlpack_dev_attr()
1050
1081
if x_dldev == (device_CPU, 0 ):
1051
1082
host_blob = x
1052
1083
else :
1084
+ # this would fail anyway
1053
1085
dlpack_capsule = dlpack_attr(max_version = (1 , 0 ), dl_device = (device_CPU, 0 ), copy = copy)
1054
1086
host_blob = from_dlpack_capsule(dlpack_capsule)
1055
- device_id = dl_device[1 ]
1056
- root_device = dpctl.SyclDevice(str (< int > device_id))
1057
- q = Device.create_device(root_device).sycl_queue
1058
- np_ary = np.asarray(host_blob)
1059
- dt = np_ary.dtype
1060
- if dt.char in " dD" and q.sycl_device.has_aspect_fp64 is False :
1061
- Xusm_dtype = (
1062
- " float32" if dt.char == " d" else " complex64"
1063
- )
1064
- else :
1065
- Xusm_dtype = dt
1066
- usm_mem = dpmem.MemoryUSMDevice(np_ary.nbytes, queue = q)
1067
- usm_ary = usm_ndarray(np_ary.shape, dtype = Xusm_dtype, buffer = usm_mem)
1068
- usm_mem.copy_from_host(np.reshape(np_ary.view(dtype = " u1" ), - 1 ))
1069
- return usm_ary
1087
+ return _to_usm_ary_from_host_blob(host_blob, dl_device[1 ])
0 commit comments