@@ -33,12 +33,18 @@ from .._backend cimport (
33
33
)
34
34
from ._usmarray cimport usm_ndarray
35
35
36
+ from platform import system as sys_platform
37
+
36
38
import numpy as np
37
39
38
40
import dpctl
39
41
import dpctl.memory as dpmem
40
42
41
43
44
+ cdef bint _IS_LINUX = sys_platform() == " Linux"
45
+
46
+ del sys_platform
47
+
42
48
cdef extern from ' dlpack/dlpack.h' nogil:
43
49
cdef int DLPACK_VERSION
44
50
@@ -140,6 +146,7 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary) except+:
140
146
cdef c_dpctl.SyclQueue ary_sycl_queue
141
147
cdef c_dpctl.SyclDevice ary_sycl_device
142
148
cdef DPCTLSyclDeviceRef pDRef = NULL
149
+ cdef DPCTLSyclDeviceRef tDRef = NULL
143
150
cdef DLManagedTensor * dlm_tensor = NULL
144
151
cdef DLTensor * dl_tensor = NULL
145
152
cdef int nd = usm_ary.get_ndim()
@@ -157,19 +164,45 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary) except+:
157
164
ary_sycl_queue = usm_ary.get_sycl_queue()
158
165
ary_sycl_device = ary_sycl_queue.get_sycl_device()
159
166
160
- # check that ary_sycl_device is a non-partitioned device
161
- pDRef = DPCTLDevice_GetParentDevice(ary_sycl_device.get_device_ref())
162
- if pDRef is not NULL :
163
- DPCTLDevice_Delete(pDRef)
164
- raise DLPackCreationError(
165
- " to_dlpack_capsule: DLPack can only export arrays allocated on "
166
- " non-partitioned SYCL devices."
167
- )
168
- default_context = dpctl.SyclQueue(ary_sycl_device).sycl_context
169
- if not usm_ary.sycl_context == default_context:
167
+ try :
168
+ if _IS_LINUX:
169
+ default_context = ary_sycl_device.sycl_platform.default_context
170
+ else :
171
+ default_context = None
172
+ except RuntimeError :
173
+ # The SYCL runtime does not support default_context, e.g. on Windows
174
+ default_context = None
175
+ if default_context is None :
176
+ # check that ary_sycl_device is a non-partitioned device
177
+ pDRef = DPCTLDevice_GetParentDevice(ary_sycl_device.get_device_ref())
178
+ if pDRef is not NULL :
179
+ DPCTLDevice_Delete(pDRef)
180
+ raise DLPackCreationError(
181
+ " to_dlpack_capsule: DLPack can only export arrays allocated "
182
+ " on non-partitioned SYCL devices on platforms where "
183
+ " default_context oneAPI extension is not supported."
184
+ )
185
+ else :
186
+ if not usm_ary.sycl_context == default_context:
187
+ raise DLPackCreationError(
188
+ " to_dlpack_capsule: DLPack can only export arrays based on USM "
189
+ " allocations bound to a default platform SYCL context"
190
+ )
191
+ # Find the unpartitioned parent of the allocation device
192
+ pDRef = DPCTLDevice_GetParentDevice(ary_sycl_device.get_device_ref())
193
+ if pDRef is not NULL :
194
+ tDRef = DPCTLDevice_GetParentDevice(pDRef)
195
+ while tDRef is not NULL :
196
+ DPCTLDevice_Delete(pDRef)
197
+ pDRef = tDRef
198
+ tDRef = DPCTLDevice_GetParentDevice(pDRef)
199
+ ary_sycl_device = c_dpctl.SyclDevice._create(pDRef)
200
+
201
+ # Find ordinal number of the parent device
202
+ device_id = ary_sycl_device.get_overall_ordinal()
203
+ if device_id < 0 :
170
204
raise DLPackCreationError(
171
- " to_dlpack_capsule: DLPack can only export arrays based on USM "
172
- " allocations bound to a default platform SYCL context"
205
+ " to_dlpack_capsule: failed to determine device_id"
173
206
)
174
207
175
208
dlm_tensor = < DLManagedTensor * > stdlib.malloc(
@@ -192,14 +225,6 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary) except+:
192
225
for i in range (nd):
193
226
shape_strides_ptr[nd + i] = strides_ptr[i]
194
227
195
- device_id = ary_sycl_device.get_overall_ordinal()
196
- if device_id < 0 :
197
- stdlib.free(shape_strides_ptr)
198
- stdlib.free(dlm_tensor)
199
- raise DLPackCreationError(
200
- " to_dlpack_capsule: failed to determine device_id"
201
- )
202
-
203
228
ary_dt = usm_ary.dtype
204
229
ary_dtk = ary_dt.kind
205
230
element_offset = usm_ary.get_offset()
@@ -278,15 +303,16 @@ cpdef usm_ndarray from_dlpack_capsule(object py_caps) except +:
278
303
success.
279
304
Raises:
280
305
TypeError: if argument is not a "dltensor" capsule.
281
- ValueError: if argument is "used_dltensor" capsule,
282
- if the USM pointer is not bound to the reconstructed
306
+ ValueError: if argument is a "used_dltensor" capsule.
307
+ BufferError: if the USM pointer is not bound to the reconstructed
283
308
sycl context, or the DLPack's device_type is not supported
284
309
by dpctl.
285
310
"""
286
311
cdef DLManagedTensor * dlm_tensor = NULL
287
312
cdef bytes usm_type
288
313
cdef size_t sz = 1
289
314
cdef int i
315
+ cdef int device_id = - 1
290
316
cdef int element_bytesize = 0
291
317
cdef Py_ssize_t offset_min = 0
292
318
cdef Py_ssize_t offset_max = 0
@@ -308,26 +334,40 @@ cpdef usm_ndarray from_dlpack_capsule(object py_caps) except +:
308
334
py_caps, " dltensor" )
309
335
# Verify that we can work with this device
310
336
if dlm_tensor.dl_tensor.device.device_type == kDLOneAPI:
311
- q = dpctl.SyclQueue(str (< int > dlm_tensor.dl_tensor.device.device_id))
337
+ device_id = dlm_tensor.dl_tensor.device.device_id
338
+ root_device = dpctl.SyclDevice(str (< int > device_id))
339
+ try :
340
+ if _IS_LINUX:
341
+ default_context = root_device.sycl_platform.default_context
342
+ else :
343
+ default_context = dpctl.SyclQueue(root_device).sycl_context
344
+ except RuntimeError :
345
+ default_context = dpctl.SyclQueue(root_device).sycl_context
312
346
if dlm_tensor.dl_tensor.data is NULL :
313
347
usm_type = b" device"
348
+ q = dpctl.SyclQueue(default_context, root_device)
314
349
else :
315
350
usm_type = c_dpmem._Memory.get_pointer_type(
316
351
< DPCTLSyclUSMRef> dlm_tensor.dl_tensor.data,
317
- < c_dpctl.SyclContext> q.sycl_context)
318
- if usm_type == b" unknown" :
319
- raise ValueError (
320
- f" Data pointer in DLPack is not bound to default sycl "
321
- " context of device '{device_id}', translated to "
322
- " {q.sycl_device.filter_string}"
352
+ < c_dpctl.SyclContext> default_context)
353
+ if usm_type == b" unknown" :
354
+ raise BufferError(
355
+ " Data pointer in DLPack is not bound to default sycl "
356
+ f" context of device '{device_id}', translated to "
357
+ f" {root_device.filter_string}"
358
+ )
359
+ alloc_device = c_dpmem._Memory.get_pointer_device(
360
+ < DPCTLSyclUSMRef> dlm_tensor.dl_tensor.data,
361
+ < c_dpctl.SyclContext> default_context
323
362
)
363
+ q = dpctl.SyclQueue(default_context, alloc_device)
324
364
if dlm_tensor.dl_tensor.dtype.bits % 8 :
325
- raise ValueError (
365
+ raise BufferError (
326
366
" Can not import DLPack tensor whose element's "
327
367
" bitsize is not a multiple of 8"
328
368
)
329
369
if dlm_tensor.dl_tensor.dtype.lanes != 1 :
330
- raise ValueError (
370
+ raise BufferError (
331
371
" Can not import DLPack tensor with lanes != 1"
332
372
)
333
373
if dlm_tensor.dl_tensor.strides is NULL :
0 commit comments