Skip to content

Commit 03e9001

Browse files
Refined from_dlpack docstrings, reorged impl of from_dlpack
Used try/except/else/finally to avoid raising an exception when another one is in flight (confusing UX). device keyword is only allowed to be (kDLCPU, 0) or (kDLOneAPI, num). Device keyword value is used to create output array, rather than device_id deduced from it.
1 parent c621707 commit 03e9001

File tree

1 file changed

+78
-41
lines changed

1 file changed

+78
-41
lines changed

dpctl/tensor/_dlpack.pyx

Lines changed: 78 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -935,9 +935,8 @@ cpdef object from_dlpack_capsule(object py_caps):
935935
"The DLPack tensor resides on unsupported device."
936936
)
937937

938-
cdef usm_ndarray _to_usm_ary_from_host_blob(object host_blob, int device_id):
939-
root_device = dpctl.SyclDevice(str(<int>device_id))
940-
q = Device.create_device(root_device).sycl_queue
938+
cdef usm_ndarray _to_usm_ary_from_host_blob(object host_blob, dev : Device):
939+
q = dev.sycl_queue
941940
np_ary = np.asarray(host_blob)
942941
dt = np_ary.dtype
943942
if dt.char in "dD" and q.sycl_device.has_aspect_fp64 is False:
@@ -952,14 +951,25 @@ cdef usm_ndarray _to_usm_ary_from_host_blob(object host_blob, int device_id):
952951
return usm_ary
953952

954953

954+
# only cdef to make it private
955+
cdef object _create_device(object device, object dl_device):
956+
if isinstance(device, Device):
957+
return device
958+
elif isinstance(device, dpctl.SyclDevice):
959+
return Device.create_device(device)
960+
else:
961+
root_device = dpctl.SyclDevice(str(<int>dl_device[1]))
962+
return Device.create_device(root_device)
963+
964+
955965
def from_dlpack(x, /, *, device=None, copy=None):
956966
""" from_dlpack(x, /, *, device=None, copy=None)
957967
958968
Constructs :class:`dpctl.tensor.usm_ndarray` instance from a Python
959969
object ``x`` that implements ``__dlpack__`` protocol.
960970
961971
Args:
962-
x (Python object):
972+
x (object):
963973
A Python object representing an array that supports
964974
``__dlpack__`` protocol.
965975
device (Optional[str,
@@ -1044,45 +1054,72 @@ def from_dlpack(x, /, *, device=None, copy=None):
10441054
)
10451055
else:
10461056
if not isinstance(device, dpctl.SyclDevice):
1047-
d = Device.create_device(device).sycl_device
1057+
device = Device.create_device(device)
1058+
d = device.sycl_device
10481059
else:
10491060
d = device
10501061
dl_device = (device_OneAPI, get_parent_device_ordinal_id(<c_dpctl.SyclDevice>d))
1062+
if dl_device is not None:
1063+
if (dl_device[0] not in [device_OneAPI, device_CPU]):
1064+
raise ValueError(
1065+
f"Argument `device`={device} is not supported."
1066+
)
1067+
got_type_error = False
1068+
got_buffer_error = False
1069+
got_other_error = False
1070+
saved_exception = None
10511071
try:
1052-
dlpack_capsule = dlpack_attr(max_version=get_build_dlpack_version(), dl_device=dl_device, copy=copy)
1053-
return from_dlpack_capsule(dlpack_capsule)
1072+
# setting max_version to minimal version that supports dl_device/copy keywords
1073+
dlpack_capsule = dlpack_attr(max_version=(1, 0), dl_device=dl_device, copy=copy)
10541074
except TypeError:
1055-
# max_version/dl_device, copy keywords are not supported by __dlpack__
1056-
x_dldev = dlpack_dev_attr()
1057-
if (dl_device is None) or (dl_device == x_dldev):
1058-
dlpack_capsule = dlpack_attr()
1059-
return from_dlpack_capsule(dlpack_capsule)
1060-
# must copy via host
1061-
if copy is False:
1062-
raise BufferError(
1063-
"Importing data via DLPack requires copying, but copy=False was provided"
1064-
)
1065-
if x_dldev == (device_CPU, 0) and dl_device[0] == device_OneAPI:
1066-
dlpack_capsule = dlpack_attr()
1067-
host_blob = from_dlpack_capsule(dlpack_capsule)
1068-
else:
1069-
raise BufferError(f"Can not import to requested device {dl_device}")
1070-
return _to_usm_ary_from_host_blob(host_blob, dl_device[1])
1071-
except (BufferError, NotImplementedError) as e:
1072-
# we are here, because dlpack_attr could not deal with requested dl_device,
1073-
# or copying was required
1074-
if copy is False:
1075-
raise BufferError(
1076-
"Importing data via DLPack requires copying, but copy=False was provided"
1077-
) from e
1078-
# must copy via host
1079-
if dl_device[0] != device_OneAPI:
1080-
raise BufferError(f"Can not import to requested device {dl_device}")
1081-
x_dldev = dlpack_dev_attr()
1082-
if x_dldev == (device_CPU, 0):
1083-
dlpack_capsule = dlpack_attr()
1084-
host_blob = from_dlpack_capsule(dlpack_capsule)
1085-
else:
1086-
dlpack_capsule = dlpack_attr(max_version=(1, 0), dl_device=(device_CPU, 0), copy=copy)
1087-
host_blob = from_dlpack_capsule(dlpack_capsule)
1088-
return _to_usm_ary_from_host_blob(host_blob, dl_device[1])
1075+
# exporter does not support max_version keyword
1076+
got_type_error = True
1077+
except (BufferError, NotImplementedError):
1078+
# Either dl_device, or copy can be satisfied
1079+
got_buffer_error = True
1080+
except Exception as e:
1081+
got_other_error = True
1082+
saved_exception = e
1083+
else:
1084+
# execution did not raise exceptions
1085+
return from_dlpack_capsule(dlpack_capsule)
1086+
finally:
1087+
if got_type_error:
1088+
# max_version/dl_device, copy keywords are not supported by __dlpack__
1089+
x_dldev = dlpack_dev_attr()
1090+
if (dl_device is None) or (dl_device == x_dldev):
1091+
dlpack_capsule = dlpack_attr()
1092+
return from_dlpack_capsule(dlpack_capsule)
1093+
# must copy via host
1094+
if copy is False:
1095+
raise BufferError(
1096+
"Importing data via DLPack requires copying, but copy=False was provided"
1097+
)
1098+
if x_dldev == (device_CPU, 0) and dl_device[0] == device_OneAPI:
1099+
dlpack_capsule = dlpack_attr()
1100+
host_blob = from_dlpack_capsule(dlpack_capsule)
1101+
else:
1102+
raise BufferError(f"Can not import to requested device {dl_device}")
1103+
dev = _create_device(device, dl_device)
1104+
return _to_usm_ary_from_host_blob(host_blob, dev)
1105+
elif got_buffer_error:
1106+
# we are here, because dlpack_attr could not deal with requested dl_device,
1107+
# or copying was required
1108+
if copy is False:
1109+
raise BufferError(
1110+
"Importing data via DLPack requires copying, but copy=False was provided"
1111+
)
1112+
# must copy via host
1113+
if dl_device[0] != device_OneAPI:
1114+
raise BufferError(f"Can not import to requested device {dl_device}")
1115+
x_dldev = dlpack_dev_attr()
1116+
if x_dldev == (device_CPU, 0):
1117+
dlpack_capsule = dlpack_attr()
1118+
host_blob = from_dlpack_capsule(dlpack_capsule)
1119+
else:
1120+
dlpack_capsule = dlpack_attr(max_version=(1, 0), dl_device=(device_CPU, 0), copy=copy)
1121+
host_blob = from_dlpack_capsule(dlpack_capsule)
1122+
dev = _create_device(device, dl_device)
1123+
return _to_usm_ary_from_host_blob(host_blob, dev)
1124+
elif got_other_error:
1125+
raise saved_exception

0 commit comments

Comments
 (0)