Skip to content

Commit 424e4c8

Browse files
Add symmetric support for containers with legacy DLPack support
For legacy containers, support device=(kDLCPU, 0) as well as oneAPI device.
1 parent 67ab1bb commit 424e4c8

File tree

1 file changed

+29
-7
lines changed

1 file changed

+29
-7
lines changed

dpctl/tensor/_dlpack.pyx

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -985,7 +985,8 @@ def from_dlpack(x, /, *, device=None, copy=None):
985985
returned by :attr:`dpctl.tensor.usm_ndarray.device`, or a
986986
2-tuple matching the format of the output of the ``__dlpack_device__``
987987
method, an integer enumerator representing the device type followed by
988-
an integer representing the index of the device.
988+
an integer representing the index of the device. The only supported
989+
:enum:`dpctl.tensor.DLDeviceType` types are "kDLCPU" and "kDLOneAPI".
989990
Default: ``None``.
990991
copy (bool, optional)
991992
Boolean indicating whether or not to copy the input.
@@ -1034,6 +1035,7 @@ def from_dlpack(x, /, *, device=None, copy=None):
10341035
10351036
C = Container(dpt.linspace(0, 100, num=20, dtype="int16"))
10361037
X = dpt.from_dlpack(C)
1038+
Y = dpt.from_dlpack(C, device=(dpt.DLDeviceType.kDLCPU, 0))
10371039
10381040
"""
10391041
dlpack_attr = getattr(x, "__dlpack__", None)
@@ -1070,6 +1072,7 @@ def from_dlpack(x, /, *, device=None, copy=None):
10701072
saved_exception = None
10711073
# First DLPack version supporting dl_device, and copy
10721074
requested_ver = (1, 0)
1075+
cpu_dev = (device_CPU, 0)
10731076
try:
10741077
# setting max_version to minimal version that supports dl_device/copy keywords
10751078
dlpack_capsule = dlpack_attr(
@@ -1105,16 +1108,35 @@ def from_dlpack(x, /, *, device=None, copy=None):
11051108
# we can only support importing to OneAPI devices
11061109
# from host, or from another oneAPI device
11071110
is_supported_x_dldev = (
1108-
x_dldev == (device_CPU, 0) or
1111+
x_dldev == cpu_dev or
11091112
(x_dldev[0] == device_OneAPI)
11101113
)
1111-
if is_supported_x_dldev and dl_device[0] == device_OneAPI:
1114+
is_supported_dl_device = (
1115+
dl_device == cpu_dev or
1116+
dl_device[0] == device_OneAPI
1117+
)
1118+
if is_supported_x_dldev and is_supported_dl_device:
11121119
dlpack_capsule = dlpack_attr()
1113-
host_blob = from_dlpack_capsule(dlpack_capsule)
1120+
blob = from_dlpack_capsule(dlpack_capsule)
11141121
else:
11151122
raise BufferError(f"Can not import to requested device {dl_device}")
11161123
dev = _create_device(device, dl_device)
1117-
return _to_usm_ary_from_host_blob(host_blob, dev)
1124+
if x_dldev == cpu_dev and dl_device == cpu_dev:
1125+
# both source and destion are CPU
1126+
return blob
1127+
elif x_dldev == cpu_dev:
1128+
# source is CPU, destingation is oneAPI
1129+
return _to_usm_ary_from_host_blob(blob, dev)
1130+
elif dl_device == cpu_dev:
1131+
# source is oneAPI, destination is CPU
1132+
cpu_caps = blob.__dlpack__(
1133+
max_version=get_build_dlpack_version(),
1134+
dl_device=cpu_dev
1135+
)
1136+
return from_dlpack_capsule(cpu_caps)
1137+
else:
1138+
import dpctl.tensor as dpt
1139+
return dpt.asarray(blob, device=dev)
11181140
elif got_buffer_error:
11191141
# we are here, because dlpack_attr could not deal with requested dl_device,
11201142
# or copying was required
@@ -1126,13 +1148,13 @@ def from_dlpack(x, /, *, device=None, copy=None):
11261148
if dl_device[0] != device_OneAPI:
11271149
raise BufferError(f"Can not import to requested device {dl_device}")
11281150
x_dldev = dlpack_dev_attr()
1129-
if x_dldev == (device_CPU, 0):
1151+
if x_dldev == cpu_dev:
11301152
dlpack_capsule = dlpack_attr()
11311153
host_blob = from_dlpack_capsule(dlpack_capsule)
11321154
else:
11331155
dlpack_capsule = dlpack_attr(
11341156
max_version=requested_ver,
1135-
dl_device=(device_CPU, 0),
1157+
dl_device=cpu_dev,
11361158
copy=copy
11371159
)
11381160
host_blob = from_dlpack_capsule(dlpack_capsule)

0 commit comments

Comments
 (0)