Skip to content

Commit 4076fb2

Browse files
Merge pull request #380 from IntelPython/feature/suai-helper
Added get_queue_ref_from_ptr_and_syclobj
2 parents a8ea6ef + b754dba commit 4076fb2

File tree

2 files changed

+68
-16
lines changed

2 files changed

+68
-16
lines changed

dpctl/memory/_memory.pxd

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,16 @@ in dpctl.memory._memory.pyx.
2222
2323
"""
2424

25-
from .._backend cimport DPCTLSyclUSMRef
25+
from .._backend cimport DPCTLSyclUSMRef, DPCTLSyclQueueRef
2626
from .._sycl_context cimport SyclContext
2727
from .._sycl_device cimport SyclDevice
2828
from .._sycl_queue cimport SyclQueue
2929

3030

31+
cdef DPCTLSyclQueueRef get_queue_ref_from_ptr_and_syclobj(
32+
DPCTLSyclUSMRef ptr, object syclobj)
33+
34+
3135
cdef public class _Memory [object Py_MemoryObject, type Py_MemoryType]:
3236
cdef DPCTLSyclUSMRef memory_ptr
3337
cdef Py_ssize_t nbytes

dpctl/memory/_memory.pyx

Lines changed: 63 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ from .._sycl_queue_manager cimport get_current_queue
3232

3333
from cpython cimport Py_buffer
3434
from cpython.bytes cimport PyBytes_AS_STRING, PyBytes_FromStringAndSize
35+
from cpython cimport pycapsule
3536

3637
import numpy as np
3738

@@ -41,10 +42,63 @@ __all__ = [
4142
"MemoryUSMDevice"
4243
]
4344

44-
cdef _throw_sycl_usm_ary_iface():
45-
raise ValueError("__sycl_usm_array_interface__ is malformed")
45+
cdef object _sycl_usm_ary_iface_error():
46+
return ValueError("__sycl_usm_array_interface__ is malformed")
4647

4748

49+
cdef DPCTLSyclQueueRef _queue_ref_copy_from_SyclQueue(SyclQueue q):
50+
return DPCTLQueue_Copy(q.get_queue_ref())
51+
52+
53+
cdef DPCTLSyclQueueRef _queue_ref_copy_from_USMRef_and_SyclContext(
54+
DPCTLSyclUSMRef ptr, SyclContext ctx):
55+
""" Obtain device from pointer and sycl context, use
56+
context and device to create a queue from which this memory
57+
can be accessible.
58+
"""
59+
cdef SyclDevice dev = _Memory.get_pointer_device(ptr, ctx)
60+
cdef DPCTLSyclContextRef CRef = NULL
61+
cdef DPCTLSyclDeviceRef DRef = NULL
62+
CRef = ctx.get_context_ref()
63+
DRef = dev.get_device_ref()
64+
return DPCTLQueue_Create(CRef, DRef, NULL, 0)
65+
66+
67+
cdef DPCTLSyclQueueRef get_queue_ref_from_ptr_and_syclobj(
68+
DPCTLSyclUSMRef ptr, object syclobj):
69+
""" Constructs queue from pointer and syclobject from
70+
__sycl_usm_array_interface__
71+
"""
72+
cdef DPCTLSyclQueueRef QRef = NULL
73+
cdef SyclContext ctx
74+
if type(syclobj) is SyclQueue:
75+
return _queue_ref_copy_from_SyclQueue(<SyclQueue> syclobj)
76+
elif type(syclobj) is SyclContext:
77+
ctx = <SyclContext>syclobj
78+
return _queue_ref_copy_from_USMRef_and_SyclContext(ptr, ctx)
79+
elif type(syclobj) is str:
80+
q = SyclQueue(syclobj)
81+
return _queue_ref_copy_from_SyclQueue(<SyclQueue> q)
82+
elif pycapsule.PyCapsule_IsValid(syclobj, "SyclQueueRef"):
83+
q = SyclQueue(syclobj)
84+
return _queue_ref_copy_from_SyclQueue(<SyclQueue> q)
85+
elif pycapsule.PyCapsule_IsValid(syclobj, "SyclContextRef"):
86+
ctx = <SyclContext>SyclContext(syclobj)
87+
return _queue_ref_copy_from_USMRef_and_SyclContext(ptr, ctx)
88+
elif hasattr(syclobj, '_get_capsule'):
89+
cap = syclobj._get_capsule()
90+
if pycapsule.PyCapsule_IsValid(cap, "SyclQueueRef"):
91+
q = SyclQueue(cap)
92+
return _queue_ref_copy_from_SyclQueue(<SyclQueue> q)
93+
elif pycapsule.PyCapsule_IsValid(cap, "SyclContexRef"):
94+
ctx = <SyclContext>SyclContext(cap)
95+
return _queue_ref_copy_from_USMRef_and_SyclContext(ptr, ctx)
96+
else:
97+
return QRef
98+
else:
99+
return QRef
100+
101+
48102
cdef void copy_via_host(void *dest_ptr, SyclQueue dest_queue,
49103
void *src_ptr, SyclQueue src_queue, size_t nbytes):
50104
"""
@@ -98,25 +152,26 @@ cdef class _BufferData:
98152
cdef Py_ssize_t arr_data_ptr
99153
cdef SyclDevice dev
100154
cdef SyclContext ctx
155+
cdef DPCTLSyclQueueRef QRef = NULL
101156

102157
if ary_version != 1:
103-
_throw_sycl_usm_ary_iface()
158+
raise _sycl_usm_ary_iface_error()
104159
if not ary_data_tuple or len(ary_data_tuple) != 2:
105-
_throw_sycl_usm_ary_iface()
160+
raise _sycl_usm_ary_iface_error()
106161
if not ary_shape or len(ary_shape) != 1 or ary_shape[0] < 1:
107162
raise ValueError
108163
try:
109164
dt = np.dtype(ary_typestr)
110165
except TypeError:
111-
_throw_sycl_usm_ary_iface()
166+
raise _sycl_usm_ary_iface_error()
112167
if (ary_strides and len(ary_strides) != 1
113168
and ary_strides[0] != dt.itemsize):
114169
raise ValueError("Must be contiguous")
115170

116171
if (not ary_syclobj or
117172
not isinstance(ary_syclobj,
118173
(dpctl.SyclQueue, dpctl.SyclContext))):
119-
_throw_sycl_usm_ary_iface()
174+
raise _sycl_usm_ary_iface_error()
120175

121176
buf = _BufferData.__new__(_BufferData)
122177
arr_data_ptr = <Py_ssize_t>ary_data_tuple[0]
@@ -125,15 +180,8 @@ cdef class _BufferData:
125180
buf.itemsize = <Py_ssize_t>(dt.itemsize)
126181
buf.nbytes = (<Py_ssize_t>ary_shape[0]) * buf.itemsize
127182

128-
if isinstance(ary_syclobj, dpctl.SyclQueue):
129-
buf.queue = <SyclQueue>ary_syclobj
130-
else:
131-
# Obtain device from pointer and context
132-
ctx = <SyclContext> ary_syclobj
133-
dev = _Memory.get_pointer_device(buf.p, ctx)
134-
# Use context and device to create a queue to
135-
# be able to copy memory
136-
buf.queue = SyclQueue._create_from_context_and_device(ctx, dev)
183+
QRef = get_queue_ref_from_ptr_and_syclobj(buf.p, ary_syclobj)
184+
buf.queue = SyclQueue._create(QRef)
137185

138186
return buf
139187

0 commit comments

Comments
 (0)