Skip to content

Commit 67ab1bb

Browse files
Expand applicability of fall-back behavior
When `from_dlpack(arr, device=dev)` is called, for `arr` object that supports legacy DLPack interface (max_version, dl_device, copy are not supported), we now support arr being device on host, that is (kDLCPU, 0), and (kDLOneAPI, different_device_id). Support for this last case is being added in this commit, as per review comment.
1 parent 25f8afb commit 67ab1bb

File tree

1 file changed

+20
-3
lines changed

1 file changed

+20
-3
lines changed

dpctl/tensor/_dlpack.pyx

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1068,9 +1068,15 @@ def from_dlpack(x, /, *, device=None, copy=None):
10681068
got_buffer_error = False
10691069
got_other_error = False
10701070
saved_exception = None
1071+
# First DLPack version supporting dl_device, and copy
1072+
requested_ver = (1, 0)
10711073
try:
10721074
# setting max_version to minimal version that supports dl_device/copy keywords
1073-
dlpack_capsule = dlpack_attr(max_version=(1, 0), dl_device=dl_device, copy=copy)
1075+
dlpack_capsule = dlpack_attr(
1076+
max_version=requested_ver,
1077+
dl_device=dl_device,
1078+
copy=copy
1079+
)
10741080
except TypeError:
10751081
# exporter does not support max_version keyword
10761082
got_type_error = True
@@ -1095,7 +1101,14 @@ def from_dlpack(x, /, *, device=None, copy=None):
10951101
raise BufferError(
10961102
"Importing data via DLPack requires copying, but copy=False was provided"
10971103
)
1098-
if x_dldev == (device_CPU, 0) and dl_device[0] == device_OneAPI:
1104+
# when max_version/dl_device/copy are not supported
1105+
# we can only support importing to OneAPI devices
1106+
# from host, or from another oneAPI device
1107+
is_supported_x_dldev = (
1108+
x_dldev == (device_CPU, 0) or
1109+
(x_dldev[0] == device_OneAPI)
1110+
)
1111+
if is_supported_x_dldev and dl_device[0] == device_OneAPI:
10991112
dlpack_capsule = dlpack_attr()
11001113
host_blob = from_dlpack_capsule(dlpack_capsule)
11011114
else:
@@ -1117,7 +1130,11 @@ def from_dlpack(x, /, *, device=None, copy=None):
11171130
dlpack_capsule = dlpack_attr()
11181131
host_blob = from_dlpack_capsule(dlpack_capsule)
11191132
else:
1120-
dlpack_capsule = dlpack_attr(max_version=(1, 0), dl_device=(device_CPU, 0), copy=copy)
1133+
dlpack_capsule = dlpack_attr(
1134+
max_version=requested_ver,
1135+
dl_device=(device_CPU, 0),
1136+
copy=copy
1137+
)
11211138
host_blob = from_dlpack_capsule(dlpack_capsule)
11221139
dev = _create_device(device, dl_device)
11231140
return _to_usm_ary_from_host_blob(host_blob, dev)

0 commit comments

Comments
 (0)