@@ -985,7 +985,8 @@ def from_dlpack(x, /, *, device=None, copy=None):
985
985
returned by :attr:`dpctl.tensor.usm_ndarray.device`, or a
986
986
2-tuple matching the format of the output of the ``__dlpack_device__``
987
987
method, an integer enumerator representing the device type followed by
988
- an integer representing the index of the device.
988
+ an integer representing the index of the device. The only supported
989
+ :enum:`dpctl.tensor.DLDeviceType` types are "kDLCPU" and "kDLOneAPI".
989
990
Default: ``None``.
990
991
copy (bool, optional)
991
992
Boolean indicating whether or not to copy the input.
@@ -1034,6 +1035,7 @@ def from_dlpack(x, /, *, device=None, copy=None):
1034
1035
1035
1036
C = Container(dpt.linspace(0, 100, num=20, dtype="int16"))
1036
1037
X = dpt.from_dlpack(C)
1038
+ Y = dpt.from_dlpack(C, device=(dpt.DLDeviceType.kDLCPU, 0))
1037
1039
1038
1040
"""
1039
1041
dlpack_attr = getattr (x, " __dlpack__" , None )
@@ -1070,6 +1072,7 @@ def from_dlpack(x, /, *, device=None, copy=None):
1070
1072
saved_exception = None
1071
1073
# First DLPack version supporting dl_device, and copy
1072
1074
requested_ver = (1 , 0 )
1075
+ cpu_dev = (device_CPU, 0 )
1073
1076
try :
1074
1077
# setting max_version to minimal version that supports dl_device/copy keywords
1075
1078
dlpack_capsule = dlpack_attr(
@@ -1105,16 +1108,35 @@ def from_dlpack(x, /, *, device=None, copy=None):
1105
1108
# we can only support importing to OneAPI devices
1106
1109
# from host, or from another oneAPI device
1107
1110
is_supported_x_dldev = (
1108
- x_dldev == (device_CPU, 0 ) or
1111
+ x_dldev == cpu_dev or
1109
1112
(x_dldev[0 ] == device_OneAPI)
1110
1113
)
1111
- if is_supported_x_dldev and dl_device[0 ] == device_OneAPI:
1114
+ is_supported_dl_device = (
1115
+ dl_device == cpu_dev or
1116
+ dl_device[0 ] == device_OneAPI
1117
+ )
1118
+ if is_supported_x_dldev and is_supported_dl_device:
1112
1119
dlpack_capsule = dlpack_attr()
1113
- host_blob = from_dlpack_capsule(dlpack_capsule)
1120
+ blob = from_dlpack_capsule(dlpack_capsule)
1114
1121
else :
1115
1122
raise BufferError(f" Can not import to requested device {dl_device}" )
1116
1123
dev = _create_device(device, dl_device)
1117
- return _to_usm_ary_from_host_blob(host_blob, dev)
1124
+ if x_dldev == cpu_dev and dl_device == cpu_dev:
1125
+ # both source and destion are CPU
1126
+ return blob
1127
+ elif x_dldev == cpu_dev:
1128
+ # source is CPU, destingation is oneAPI
1129
+ return _to_usm_ary_from_host_blob(blob, dev)
1130
+ elif dl_device == cpu_dev:
1131
+ # source is oneAPI, destination is CPU
1132
+ cpu_caps = blob.__dlpack__(
1133
+ max_version = get_build_dlpack_version(),
1134
+ dl_device = cpu_dev
1135
+ )
1136
+ return from_dlpack_capsule(cpu_caps)
1137
+ else :
1138
+ import dpctl.tensor as dpt
1139
+ return dpt.asarray(blob, device = dev)
1118
1140
elif got_buffer_error:
1119
1141
# we are here, because dlpack_attr could not deal with requested dl_device,
1120
1142
# or copying was required
@@ -1126,13 +1148,13 @@ def from_dlpack(x, /, *, device=None, copy=None):
1126
1148
if dl_device[0 ] != device_OneAPI:
1127
1149
raise BufferError(f" Can not import to requested device {dl_device}" )
1128
1150
x_dldev = dlpack_dev_attr()
1129
- if x_dldev == (device_CPU, 0 ) :
1151
+ if x_dldev == cpu_dev :
1130
1152
dlpack_capsule = dlpack_attr()
1131
1153
host_blob = from_dlpack_capsule(dlpack_capsule)
1132
1154
else :
1133
1155
dlpack_capsule = dlpack_attr(
1134
1156
max_version = requested_ver,
1135
- dl_device = (device_CPU, 0 ) ,
1157
+ dl_device = cpu_dev ,
1136
1158
copy = copy
1137
1159
)
1138
1160
host_blob = from_dlpack_capsule(dlpack_capsule)
0 commit comments