Skip to content

Commit 530a7d7

Browse files
Merge pull request #984 from IntelPython/extend-dlpack-for-update-in-compiler
Added support for consuming DLPack allocated on a sub-device
2 parents c1335f1 + 7bd938d commit 530a7d7

File tree

3 files changed

+73
-32
lines changed

3 files changed

+73
-32
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2121
* Introduced `"syclinterface/dpctl_sycl_types_casters.hpp"` header file with declaration of conversion routines between SYCL type pointers and SyclInterface library opaque pointers [#960](https://github.com/IntelPython/dpctl/pull/960).
2222
* Added C-API to `dpctl.program.SyclKernel` and `dpctl.program.SyclProgram`. Added type casters for new types to "dpctl4pybind11" and added an example demonstrating its use [#970](https://github.com/IntelPython/dpctl/pull/970).
2323
* Introduced "dpctl/sycl.pxd" Cython declaration file to streamline use of SYCL functions from Cython, and added an example demonstrating its use [#981](https://github.com/IntelPython/dpctl/pull/981).
24+
* Added experimental support for sharing data allocated on sub-devices via dlpack [#984](https://github.com/IntelPython/dpctl/pull/984).
2425

2526
### Changed
2627
* Improved queue compatibility testing in `dpctl.tensor`'s implementation module [#900](https://github.com/IntelPython/dpctl/pull/900).

dpctl/_sycl_platform.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ cdef class SyclPlatform(_SyclPlatform):
272272
)
273273

274274
if (CRef == NULL):
275-
raise
275+
raise RuntimeError("Getting default error ran into a problem")
276276
else:
277277
return SyclContext._create(CRef)
278278

dpctl/tensor/_dlpack.pyx

Lines changed: 71 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,18 @@ from .._backend cimport (
3333
)
3434
from ._usmarray cimport usm_ndarray
3535

36+
from platform import system as sys_platform
37+
3638
import numpy as np
3739

3840
import dpctl
3941
import dpctl.memory as dpmem
4042

4143

44+
cdef bint _IS_LINUX = sys_platform() == "Linux"
45+
46+
del sys_platform
47+
4248
cdef extern from 'dlpack/dlpack.h' nogil:
4349
cdef int DLPACK_VERSION
4450

@@ -140,6 +146,7 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary) except+:
140146
cdef c_dpctl.SyclQueue ary_sycl_queue
141147
cdef c_dpctl.SyclDevice ary_sycl_device
142148
cdef DPCTLSyclDeviceRef pDRef = NULL
149+
cdef DPCTLSyclDeviceRef tDRef = NULL
143150
cdef DLManagedTensor *dlm_tensor = NULL
144151
cdef DLTensor *dl_tensor = NULL
145152
cdef int nd = usm_ary.get_ndim()
@@ -157,19 +164,45 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary) except+:
157164
ary_sycl_queue = usm_ary.get_sycl_queue()
158165
ary_sycl_device = ary_sycl_queue.get_sycl_device()
159166

160-
# check that ary_sycl_device is a non-partitioned device
161-
pDRef = DPCTLDevice_GetParentDevice(ary_sycl_device.get_device_ref())
162-
if pDRef is not NULL:
163-
DPCTLDevice_Delete(pDRef)
164-
raise DLPackCreationError(
165-
"to_dlpack_capsule: DLPack can only export arrays allocated on "
166-
"non-partitioned SYCL devices."
167-
)
168-
default_context = dpctl.SyclQueue(ary_sycl_device).sycl_context
169-
if not usm_ary.sycl_context == default_context:
167+
try:
168+
if _IS_LINUX:
169+
default_context = ary_sycl_device.sycl_platform.default_context
170+
else:
171+
default_context = None
172+
except RuntimeError:
173+
# RT does not support default_context, e.g. Windows
174+
default_context = None
175+
if default_context is None:
176+
# check that ary_sycl_device is a non-partitioned device
177+
pDRef = DPCTLDevice_GetParentDevice(ary_sycl_device.get_device_ref())
178+
if pDRef is not NULL:
179+
DPCTLDevice_Delete(pDRef)
180+
raise DLPackCreationError(
181+
"to_dlpack_capsule: DLPack can only export arrays allocated "
182+
"on non-partitioned SYCL devices on platforms where "
183+
"default_context oneAPI extension is not supported."
184+
)
185+
else:
186+
if not usm_ary.sycl_context == default_context:
187+
raise DLPackCreationError(
188+
"to_dlpack_capsule: DLPack can only export arrays based on USM "
189+
"allocations bound to a default platform SYCL context"
190+
)
191+
# Find the unpartitioned parent of the allocation device
192+
pDRef = DPCTLDevice_GetParentDevice(ary_sycl_device.get_device_ref())
193+
if pDRef is not NULL:
194+
tDRef = DPCTLDevice_GetParentDevice(pDRef)
195+
while tDRef is not NULL:
196+
DPCTLDevice_Delete(pDRef)
197+
pDRef = tDRef
198+
tDRef = DPCTLDevice_GetParentDevice(pDRef)
199+
ary_sycl_device = c_dpctl.SyclDevice._create(pDRef)
200+
201+
# Find ordinal number of the parent device
202+
device_id = ary_sycl_device.get_overall_ordinal()
203+
if device_id < 0:
170204
raise DLPackCreationError(
171-
"to_dlpack_capsule: DLPack can only export arrays based on USM "
172-
"allocations bound to a default platform SYCL context"
205+
"to_dlpack_capsule: failed to determine device_id"
173206
)
174207

175208
dlm_tensor = <DLManagedTensor *> stdlib.malloc(
@@ -192,14 +225,6 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary) except+:
192225
for i in range(nd):
193226
shape_strides_ptr[nd + i] = strides_ptr[i]
194227

195-
device_id = ary_sycl_device.get_overall_ordinal()
196-
if device_id < 0:
197-
stdlib.free(shape_strides_ptr)
198-
stdlib.free(dlm_tensor)
199-
raise DLPackCreationError(
200-
"to_dlpack_capsule: failed to determine device_id"
201-
)
202-
203228
ary_dt = usm_ary.dtype
204229
ary_dtk = ary_dt.kind
205230
element_offset = usm_ary.get_offset()
@@ -278,15 +303,16 @@ cpdef usm_ndarray from_dlpack_capsule(object py_caps) except +:
278303
success.
279304
Raises:
280305
TypeError: if argument is not a "dltensor" capsule.
281-
ValueError: if argument is "used_dltensor" capsule,
282-
if the USM pointer is not bound to the reconstructed
306+
ValueError: if argument is "used_dltensor" capsule
307+
BufferError: if the USM pointer is not bound to the reconstructed
283308
sycl context, or the DLPack's device_type is not supported
284309
by dpctl.
285310
"""
286311
cdef DLManagedTensor *dlm_tensor = NULL
287312
cdef bytes usm_type
288313
cdef size_t sz = 1
289314
cdef int i
315+
cdef int device_id = -1
290316
cdef int element_bytesize = 0
291317
cdef Py_ssize_t offset_min = 0
292318
cdef Py_ssize_t offset_max = 0
@@ -308,26 +334,40 @@ cpdef usm_ndarray from_dlpack_capsule(object py_caps) except +:
308334
py_caps, "dltensor")
309335
# Verify that we can work with this device
310336
if dlm_tensor.dl_tensor.device.device_type == kDLOneAPI:
311-
q = dpctl.SyclQueue(str(<int>dlm_tensor.dl_tensor.device.device_id))
337+
device_id = dlm_tensor.dl_tensor.device.device_id
338+
root_device = dpctl.SyclDevice(str(<int>device_id))
339+
try:
340+
if _IS_LINUX:
341+
default_context = root_device.sycl_platform.default_context
342+
else:
343+
default_context = dpctl.SyclQueue(root_device).sycl_context
344+
except RuntimeError:
345+
default_context = dpctl.SyclQueue(root_device).sycl_context
312346
if dlm_tensor.dl_tensor.data is NULL:
313347
usm_type = b"device"
348+
q = dpctl.SyclQueue(default_context, root_device)
314349
else:
315350
usm_type = c_dpmem._Memory.get_pointer_type(
316351
<DPCTLSyclUSMRef> dlm_tensor.dl_tensor.data,
317-
<c_dpctl.SyclContext>q.sycl_context)
318-
if usm_type == b"unknown":
319-
raise ValueError(
320-
f"Data pointer in DLPack is not bound to default sycl "
321-
"context of device '{device_id}', translated to "
322-
"{q.sycl_device.filter_string}"
352+
<c_dpctl.SyclContext>default_context)
353+
if usm_type == b"unknown":
354+
raise BufferError(
355+
"Data pointer in DLPack is not bound to default sycl "
356+
f"context of device '{device_id}', translated to "
357+
f"{root_device.filter_string}"
358+
)
359+
alloc_device = c_dpmem._Memory.get_pointer_device(
360+
<DPCTLSyclUSMRef> dlm_tensor.dl_tensor.data,
361+
<c_dpctl.SyclContext>default_context
323362
)
363+
q = dpctl.SyclQueue(default_context, alloc_device)
324364
if dlm_tensor.dl_tensor.dtype.bits % 8:
325-
raise ValueError(
365+
raise BufferError(
326366
"Can not import DLPack tensor whose element's "
327367
"bitsize is not a multiple of 8"
328368
)
329369
if dlm_tensor.dl_tensor.dtype.lanes != 1:
330-
raise ValueError(
370+
raise BufferError(
331371
"Can not import DLPack tensor with lanes != 1"
332372
)
333373
if dlm_tensor.dl_tensor.strides is NULL:

0 commit comments

Comments
 (0)