@@ -935,9 +935,8 @@ cpdef object from_dlpack_capsule(object py_caps):
935
935
" The DLPack tensor resides on unsupported device."
936
936
)
937
937
938
- cdef usm_ndarray _to_usm_ary_from_host_blob(object host_blob, int device_id):
939
- root_device = dpctl.SyclDevice(str (< int > device_id))
940
- q = Device.create_device(root_device).sycl_queue
938
+ cdef usm_ndarray _to_usm_ary_from_host_blob(object host_blob, dev : Device):
939
+ q = dev.sycl_queue
941
940
np_ary = np.asarray(host_blob)
942
941
dt = np_ary.dtype
943
942
if dt.char in " dD" and q.sycl_device.has_aspect_fp64 is False :
@@ -952,14 +951,25 @@ cdef usm_ndarray _to_usm_ary_from_host_blob(object host_blob, int device_id):
952
951
return usm_ary
953
952
954
953
954
+ # only cdef to make it private
955
+ cdef object _create_device(object device, object dl_device):
956
+ if isinstance (device, Device):
957
+ return device
958
+ elif isinstance (device, dpctl.SyclDevice):
959
+ return Device.create_device(device)
960
+ else :
961
+ root_device = dpctl.SyclDevice(str (< int > dl_device[1 ]))
962
+ return Device.create_device(root_device)
963
+
964
+
955
965
def from_dlpack (x , /, *, device = None , copy = None ):
956
966
""" from_dlpack(x, /, *, device=None, copy=None)
957
967
958
968
Constructs :class:`dpctl.tensor.usm_ndarray` instance from a Python
959
969
object ``x`` that implements ``__dlpack__`` protocol.
960
970
961
971
Args:
962
- x (Python object):
972
+ x (object):
963
973
A Python object representing an array that supports
964
974
``__dlpack__`` protocol.
965
975
device (Optional[str,
@@ -1044,45 +1054,72 @@ def from_dlpack(x, /, *, device=None, copy=None):
1044
1054
)
1045
1055
else :
1046
1056
if not isinstance (device, dpctl.SyclDevice):
1047
- d = Device.create_device(device).sycl_device
1057
+ device = Device.create_device(device)
1058
+ d = device.sycl_device
1048
1059
else :
1049
1060
d = device
1050
1061
dl_device = (device_OneAPI, get_parent_device_ordinal_id(< c_dpctl.SyclDevice> d))
1062
+ if dl_device is not None :
1063
+ if (dl_device[0 ] not in [device_OneAPI, device_CPU]):
1064
+ raise ValueError (
1065
+ f" Argument `device`={device} is not supported."
1066
+ )
1067
+ got_type_error = False
1068
+ got_buffer_error = False
1069
+ got_other_error = False
1070
+ saved_exception = None
1051
1071
try :
1052
- dlpack_capsule = dlpack_attr( max_version = get_build_dlpack_version(), dl_device = dl_device, copy = copy)
1053
- return from_dlpack_capsule(dlpack_capsule )
1072
+ # setting max_version to minimal version that supports dl_device/ copy keywords
1073
+ dlpack_capsule = dlpack_attr( max_version = ( 1 , 0 ), dl_device = dl_device, copy = copy )
1054
1074
except TypeError :
1055
- # max_version/dl_device, copy keywords are not supported by __dlpack__
1056
- x_dldev = dlpack_dev_attr()
1057
- if (dl_device is None ) or (dl_device == x_dldev):
1058
- dlpack_capsule = dlpack_attr()
1059
- return from_dlpack_capsule(dlpack_capsule)
1060
- # must copy via host
1061
- if copy is False :
1062
- raise BufferError(
1063
- " Importing data via DLPack requires copying, but copy=False was provided"
1064
- )
1065
- if x_dldev == (device_CPU, 0 ) and dl_device[0 ] == device_OneAPI:
1066
- dlpack_capsule = dlpack_attr()
1067
- host_blob = from_dlpack_capsule(dlpack_capsule)
1068
- else :
1069
- raise BufferError(f" Can not import to requested device {dl_device}" )
1070
- return _to_usm_ary_from_host_blob(host_blob, dl_device[1 ])
1071
- except (BufferError, NotImplementedError ) as e:
1072
- # we are here, because dlpack_attr could not deal with requested dl_device,
1073
- # or copying was required
1074
- if copy is False :
1075
- raise BufferError(
1076
- " Importing data via DLPack requires copying, but copy=False was provided"
1077
- ) from e
1078
- # must copy via host
1079
- if dl_device[0 ] != device_OneAPI:
1080
- raise BufferError(f" Can not import to requested device {dl_device}" )
1081
- x_dldev = dlpack_dev_attr()
1082
- if x_dldev == (device_CPU, 0 ):
1083
- dlpack_capsule = dlpack_attr()
1084
- host_blob = from_dlpack_capsule(dlpack_capsule)
1085
- else :
1086
- dlpack_capsule = dlpack_attr(max_version = (1 , 0 ), dl_device = (device_CPU, 0 ), copy = copy)
1087
- host_blob = from_dlpack_capsule(dlpack_capsule)
1088
- return _to_usm_ary_from_host_blob(host_blob, dl_device[1 ])
1075
+ # exporter does not support max_version keyword
1076
+ got_type_error = True
1077
+ except (BufferError, NotImplementedError ):
1078
+ # Either dl_device, or copy can be satisfied
1079
+ got_buffer_error = True
1080
+ except Exception as e:
1081
+ got_other_error = True
1082
+ saved_exception = e
1083
+ else :
1084
+ # execution did not raise exceptions
1085
+ return from_dlpack_capsule(dlpack_capsule)
1086
+ finally :
1087
+ if got_type_error:
1088
+ # max_version/dl_device, copy keywords are not supported by __dlpack__
1089
+ x_dldev = dlpack_dev_attr()
1090
+ if (dl_device is None ) or (dl_device == x_dldev):
1091
+ dlpack_capsule = dlpack_attr()
1092
+ return from_dlpack_capsule(dlpack_capsule)
1093
+ # must copy via host
1094
+ if copy is False :
1095
+ raise BufferError(
1096
+ " Importing data via DLPack requires copying, but copy=False was provided"
1097
+ )
1098
+ if x_dldev == (device_CPU, 0 ) and dl_device[0 ] == device_OneAPI:
1099
+ dlpack_capsule = dlpack_attr()
1100
+ host_blob = from_dlpack_capsule(dlpack_capsule)
1101
+ else :
1102
+ raise BufferError(f" Can not import to requested device {dl_device}" )
1103
+ dev = _create_device(device, dl_device)
1104
+ return _to_usm_ary_from_host_blob(host_blob, dev)
1105
+ elif got_buffer_error:
1106
+ # we are here, because dlpack_attr could not deal with requested dl_device,
1107
+ # or copying was required
1108
+ if copy is False :
1109
+ raise BufferError(
1110
+ " Importing data via DLPack requires copying, but copy=False was provided"
1111
+ )
1112
+ # must copy via host
1113
+ if dl_device[0 ] != device_OneAPI:
1114
+ raise BufferError(f" Can not import to requested device {dl_device}" )
1115
+ x_dldev = dlpack_dev_attr()
1116
+ if x_dldev == (device_CPU, 0 ):
1117
+ dlpack_capsule = dlpack_attr()
1118
+ host_blob = from_dlpack_capsule(dlpack_capsule)
1119
+ else :
1120
+ dlpack_capsule = dlpack_attr(max_version = (1 , 0 ), dl_device = (device_CPU, 0 ), copy = copy)
1121
+ host_blob = from_dlpack_capsule(dlpack_capsule)
1122
+ dev = _create_device(device, dl_device)
1123
+ return _to_usm_ary_from_host_blob(host_blob, dev)
1124
+ elif got_other_error:
1125
+ raise saved_exception
0 commit comments