@@ -32,6 +32,7 @@ from .._sycl_queue_manager cimport get_current_queue
32
32
33
33
from cpython cimport Py_buffer
34
34
from cpython.bytes cimport PyBytes_AS_STRING, PyBytes_FromStringAndSize
35
+ from cpython cimport pycapsule
35
36
36
37
import numpy as np
37
38
@@ -41,10 +42,63 @@ __all__ = [
41
42
" MemoryUSMDevice"
42
43
]
43
44
44
- cdef _throw_sycl_usm_ary_iface ():
45
- raise ValueError (" __sycl_usm_array_interface__ is malformed" )
45
+ cdef object _sycl_usm_ary_iface_error ():
46
+ return ValueError (" __sycl_usm_array_interface__ is malformed" )
46
47
47
48
49
+ cdef DPCTLSyclQueueRef _queue_ref_copy_from_SyclQueue(SyclQueue q):
50
+ return DPCTLQueue_Copy(q.get_queue_ref())
51
+
52
+
53
+ cdef DPCTLSyclQueueRef _queue_ref_copy_from_USMRef_and_SyclContext(
54
+ DPCTLSyclUSMRef ptr, SyclContext ctx):
55
+ """ Obtain device from pointer and sycl context, use
56
+ context and device to create a queue from which this memory
57
+ can be accessible.
58
+ """
59
+ cdef SyclDevice dev = _Memory.get_pointer_device(ptr, ctx)
60
+ cdef DPCTLSyclContextRef CRef = NULL
61
+ cdef DPCTLSyclDeviceRef DRef = NULL
62
+ CRef = ctx.get_context_ref()
63
+ DRef = dev.get_device_ref()
64
+ return DPCTLQueue_Create(CRef, DRef, NULL , 0 )
65
+
66
+
67
+ cdef DPCTLSyclQueueRef get_queue_ref_from_ptr_and_syclobj(
68
+ DPCTLSyclUSMRef ptr, object syclobj):
69
+ """ Constructs queue from pointer and syclobject from
70
+ __sycl_usm_array_interface__
71
+ """
72
+ cdef DPCTLSyclQueueRef QRef = NULL
73
+ cdef SyclContext ctx
74
+ if type (syclobj) is SyclQueue:
75
+ return _queue_ref_copy_from_SyclQueue(< SyclQueue> syclobj)
76
+ elif type (syclobj) is SyclContext:
77
+ ctx = < SyclContext> syclobj
78
+ return _queue_ref_copy_from_USMRef_and_SyclContext(ptr, ctx)
79
+ elif type (syclobj) is str :
80
+ q = SyclQueue(syclobj)
81
+ return _queue_ref_copy_from_SyclQueue(< SyclQueue> q)
82
+ elif pycapsule.PyCapsule_IsValid(syclobj, " SyclQueueRef" ):
83
+ q = SyclQueue(syclobj)
84
+ return _queue_ref_copy_from_SyclQueue(< SyclQueue> q)
85
+ elif pycapsule.PyCapsule_IsValid(syclobj, " SyclContextRef" ):
86
+ ctx = < SyclContext> SyclContext(syclobj)
87
+ return _queue_ref_copy_from_USMRef_and_SyclContext(ptr, ctx)
88
+ elif hasattr (syclobj, ' _get_capsule' ):
89
+ cap = syclobj._get_capsule()
90
+ if pycapsule.PyCapsule_IsValid(cap, " SyclQueueRef" ):
91
+ q = SyclQueue(cap)
92
+ return _queue_ref_copy_from_SyclQueue(< SyclQueue> q)
93
+ elif pycapsule.PyCapsule_IsValid(cap, " SyclContexRef" ):
94
+ ctx = < SyclContext> SyclContext(cap)
95
+ return _queue_ref_copy_from_USMRef_and_SyclContext(ptr, ctx)
96
+ else :
97
+ return QRef
98
+ else :
99
+ return QRef
100
+
101
+
48
102
cdef void copy_via_host(void * dest_ptr, SyclQueue dest_queue,
49
103
void * src_ptr, SyclQueue src_queue, size_t nbytes):
50
104
"""
@@ -98,25 +152,26 @@ cdef class _BufferData:
98
152
cdef Py_ssize_t arr_data_ptr
99
153
cdef SyclDevice dev
100
154
cdef SyclContext ctx
155
+ cdef DPCTLSyclQueueRef QRef = NULL
101
156
102
157
if ary_version != 1 :
103
- _throw_sycl_usm_ary_iface ()
158
+ raise _sycl_usm_ary_iface_error ()
104
159
if not ary_data_tuple or len (ary_data_tuple) != 2 :
105
- _throw_sycl_usm_ary_iface ()
160
+ raise _sycl_usm_ary_iface_error ()
106
161
if not ary_shape or len (ary_shape) != 1 or ary_shape[0 ] < 1 :
107
162
raise ValueError
108
163
try :
109
164
dt = np.dtype(ary_typestr)
110
165
except TypeError :
111
- _throw_sycl_usm_ary_iface ()
166
+ raise _sycl_usm_ary_iface_error ()
112
167
if (ary_strides and len (ary_strides) != 1
113
168
and ary_strides[0 ] != dt.itemsize):
114
169
raise ValueError (" Must be contiguous" )
115
170
116
171
if (not ary_syclobj or
117
172
not isinstance (ary_syclobj,
118
173
(dpctl.SyclQueue, dpctl.SyclContext))):
119
- _throw_sycl_usm_ary_iface ()
174
+ raise _sycl_usm_ary_iface_error ()
120
175
121
176
buf = _BufferData.__new__ (_BufferData)
122
177
arr_data_ptr = < Py_ssize_t> ary_data_tuple[0 ]
@@ -125,15 +180,8 @@ cdef class _BufferData:
125
180
buf.itemsize = < Py_ssize_t> (dt.itemsize)
126
181
buf.nbytes = (< Py_ssize_t> ary_shape[0 ]) * buf.itemsize
127
182
128
- if isinstance (ary_syclobj, dpctl.SyclQueue):
129
- buf.queue = < SyclQueue> ary_syclobj
130
- else :
131
- # Obtain device from pointer and context
132
- ctx = < SyclContext> ary_syclobj
133
- dev = _Memory.get_pointer_device(buf.p, ctx)
134
- # Use context and device to create a queue to
135
- # be able to copy memory
136
- buf.queue = SyclQueue._create_from_context_and_device(ctx, dev)
183
+ QRef = get_queue_ref_from_ptr_and_syclobj(buf.p, ary_syclobj)
184
+ buf.queue = SyclQueue._create(QRef)
137
185
138
186
return buf
139
187
0 commit comments