Skip to content

Commit e87c669

Browse files
To ensure same validation across branches, compute host_blob by roundtripping it through dlpack
1 parent 32196f5 commit e87c669

File tree

1 file changed

+19
-17
lines changed

1 file changed

+19
-17
lines changed

dpctl/tensor/_dlpack.pyx

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1033,22 +1033,22 @@ def from_dlpack(x, /, *, device=None, copy=None):
10331033
f"The argument of type {type(x)} does not implement "
10341034
"`__dlpack__` and `__dlpack_device__` methods."
10351035
)
1036-
try:
1037-
# device is converted to a dlpack_device if necessary
1038-
dl_device = None
1039-
if device:
1040-
if isinstance(device, tuple):
1041-
dl_device = device
1042-
if len(dl_device) != 2:
1043-
raise ValueError(
1044-
"Argument `device` specified as a tuple must have length 2"
1045-
)
1036+
# device is converted to a dlpack_device if necessary
1037+
dl_device = None
1038+
if device:
1039+
if isinstance(device, tuple):
1040+
dl_device = device
1041+
if len(dl_device) != 2:
1042+
raise ValueError(
1043+
"Argument `device` specified as a tuple must have length 2"
1044+
)
1045+
else:
1046+
if not isinstance(device, dpctl.SyclDevice):
1047+
d = Device.create_device(device).sycl_device
10461048
else:
1047-
if not isinstance(device, dpctl.SyclDevice):
1048-
d = Device.create_device(device).sycl_device
1049-
else:
1050-
d = device
1051-
dl_device = (device_OneAPI, get_parent_device_ordinal_id(<c_dpctl.SyclDevice>d))
1049+
d = device
1050+
dl_device = (device_OneAPI, get_parent_device_ordinal_id(<c_dpctl.SyclDevice>d))
1051+
try:
10521052
dlpack_capsule = dlpack_attr(max_version=get_build_dlpack_version(), dl_device=dl_device, copy=copy)
10531053
return from_dlpack_capsule(dlpack_capsule)
10541054
except TypeError:
@@ -1063,7 +1063,8 @@ def from_dlpack(x, /, *, device=None, copy=None):
10631063
"Importing data via DLPack requires copying, but copy=False was provided"
10641064
)
10651065
if x_dldev == (device_CPU, 0) and dl_device[0] == device_OneAPI:
1066-
host_blob = x
1066+
dlpack_capsule = dlpack_attr()
1067+
host_blob = from_dlpack_capsule(dlpack_capsule)
10671068
else:
10681069
raise BufferError(f"Can not import to requested device {dl_device}")
10691070
return _to_usm_ary_from_host_blob(host_blob, dl_device[1])
@@ -1079,7 +1080,8 @@ def from_dlpack(x, /, *, device=None, copy=None):
10791080
raise BufferError(f"Can not import to requested device {dl_device}")
10801081
x_dldev = dlpack_dev_attr()
10811082
if x_dldev == (device_CPU, 0):
1082-
host_blob = x
1083+
dlpack_capsule = dlpack_attr()
1084+
host_blob = from_dlpack_capsule(dlpack_capsule)
10831085
else:
10841086
dlpack_capsule = dlpack_attr(max_version=(1, 0), dl_device=(device_CPU, 0), copy=copy)
10851087
host_blob = from_dlpack_capsule(dlpack_capsule)

0 commit comments

Comments
 (0)