@@ -1033,22 +1033,22 @@ def from_dlpack(x, /, *, device=None, copy=None):
1033
1033
f" The argument of type {type(x)} does not implement "
1034
1034
" `__dlpack__` and `__dlpack_device__` methods."
1035
1035
)
1036
- try :
1037
- # device is converted to a dlpack_device if necessary
1038
- dl_device = None
1039
- if device:
1040
- if isinstance (device, tuple ):
1041
- dl_device = device
1042
- if len (dl_device) != 2 :
1043
- raise ValueError (
1044
- " Argument `device` specified as a tuple must have length 2"
1045
- )
1036
+ # device is converted to a dlpack_device if necessary
1037
+ dl_device = None
1038
+ if device:
1039
+ if isinstance (device, tuple ):
1040
+ dl_device = device
1041
+ if len (dl_device) != 2 :
1042
+ raise ValueError (
1043
+ " Argument `device` specified as a tuple must have length 2"
1044
+ )
1045
+ else :
1046
+ if not isinstance (device, dpctl.SyclDevice):
1047
+ d = Device.create_device(device).sycl_device
1046
1048
else :
1047
- if not isinstance (device, dpctl.SyclDevice):
1048
- d = Device.create_device(device).sycl_device
1049
- else :
1050
- d = device
1051
- dl_device = (device_OneAPI, get_parent_device_ordinal_id(< c_dpctl.SyclDevice> d))
1049
+ d = device
1050
+ dl_device = (device_OneAPI, get_parent_device_ordinal_id(< c_dpctl.SyclDevice> d))
1051
+ try :
1052
1052
dlpack_capsule = dlpack_attr(max_version = get_build_dlpack_version(), dl_device = dl_device, copy = copy)
1053
1053
return from_dlpack_capsule(dlpack_capsule)
1054
1054
except TypeError :
@@ -1063,7 +1063,8 @@ def from_dlpack(x, /, *, device=None, copy=None):
1063
1063
" Importing data via DLPack requires copying, but copy=False was provided"
1064
1064
)
1065
1065
if x_dldev == (device_CPU, 0 ) and dl_device[0 ] == device_OneAPI:
1066
- host_blob = x
1066
+ dlpack_capsule = dlpack_attr()
1067
+ host_blob = from_dlpack_capsule(dlpack_capsule)
1067
1068
else :
1068
1069
raise BufferError(f" Can not import to requested device {dl_device}" )
1069
1070
return _to_usm_ary_from_host_blob(host_blob, dl_device[1 ])
@@ -1079,7 +1080,8 @@ def from_dlpack(x, /, *, device=None, copy=None):
1079
1080
raise BufferError(f" Can not import to requested device {dl_device}" )
1080
1081
x_dldev = dlpack_dev_attr()
1081
1082
if x_dldev == (device_CPU, 0 ):
1082
- host_blob = x
1083
+ dlpack_capsule = dlpack_attr()
1084
+ host_blob = from_dlpack_capsule(dlpack_capsule)
1083
1085
else :
1084
1086
dlpack_capsule = dlpack_attr(max_version = (1 , 0 ), dl_device = (device_CPU, 0 ), copy = copy)
1085
1087
host_blob = from_dlpack_capsule(dlpack_capsule)
0 commit comments