Skip to content

Commit 465f2f2

Browse files
implemented feature #396
1 parent d97370a commit 465f2f2

File tree

3 files changed

+282
-118
lines changed

3 files changed

+282
-118
lines changed

dpctl/memory/_memory.pyx

Lines changed: 5 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -35,68 +35,15 @@ from cpython cimport pycapsule
3535

3636
import numpy as np
3737
import numbers
38+
import collections
3839

3940
__all__ = [
4041
"MemoryUSMShared",
4142
"MemoryUSMHost",
4243
"MemoryUSMDevice"
4344
]
4445

45-
cdef object _sycl_usm_ary_iface_error():
46-
return ValueError("__sycl_usm_array_interface__ is malformed")
47-
48-
49-
cdef DPCTLSyclQueueRef _queue_ref_copy_from_SyclQueue(SyclQueue q):
50-
return DPCTLQueue_Copy(q.get_queue_ref())
51-
52-
53-
cdef DPCTLSyclQueueRef _queue_ref_copy_from_USMRef_and_SyclContext(
54-
DPCTLSyclUSMRef ptr, SyclContext ctx):
55-
""" Obtain device from pointer and sycl context, use
56-
context and device to create a queue from which this memory
57-
can be accessible.
58-
"""
59-
cdef SyclDevice dev = _Memory.get_pointer_device(ptr, ctx)
60-
cdef DPCTLSyclContextRef CRef = NULL
61-
cdef DPCTLSyclDeviceRef DRef = NULL
62-
CRef = ctx.get_context_ref()
63-
DRef = dev.get_device_ref()
64-
return DPCTLQueue_Create(CRef, DRef, NULL, 0)
65-
66-
67-
cdef DPCTLSyclQueueRef get_queue_ref_from_ptr_and_syclobj(
68-
DPCTLSyclUSMRef ptr, object syclobj):
69-
""" Constructs queue from pointer and syclobject from
70-
__sycl_usm_array_interface__
71-
"""
72-
cdef DPCTLSyclQueueRef QRef = NULL
73-
cdef SyclContext ctx
74-
if type(syclobj) is SyclQueue:
75-
return _queue_ref_copy_from_SyclQueue(<SyclQueue> syclobj)
76-
elif type(syclobj) is SyclContext:
77-
ctx = <SyclContext>syclobj
78-
return _queue_ref_copy_from_USMRef_and_SyclContext(ptr, ctx)
79-
elif type(syclobj) is str:
80-
q = SyclQueue(syclobj)
81-
return _queue_ref_copy_from_SyclQueue(<SyclQueue> q)
82-
elif pycapsule.PyCapsule_IsValid(syclobj, "SyclQueueRef"):
83-
q = SyclQueue(syclobj)
84-
return _queue_ref_copy_from_SyclQueue(<SyclQueue> q)
85-
elif pycapsule.PyCapsule_IsValid(syclobj, "SyclContextRef"):
86-
ctx = <SyclContext>SyclContext(syclobj)
87-
return _queue_ref_copy_from_USMRef_and_SyclContext(ptr, ctx)
88-
elif hasattr(syclobj, '_get_capsule'):
89-
cap = syclobj._get_capsule()
90-
if pycapsule.PyCapsule_IsValid(cap, "SyclQueueRef"):
91-
q = SyclQueue(cap)
92-
return _queue_ref_copy_from_SyclQueue(<SyclQueue> q)
93-
elif pycapsule.PyCapsule_IsValid(cap, "SyclContexRef"):
94-
ctx = <SyclContext>SyclContext(cap)
95-
return _queue_ref_copy_from_USMRef_and_SyclContext(ptr, ctx)
96-
else:
97-
return QRef
98-
else:
99-
return QRef
46+
include "_sycl_usm_array_interface_utils.pxi"
10047

10148

10249
cdef void copy_via_host(void *dest_ptr, SyclQueue dest_queue,
@@ -126,66 +73,6 @@ cdef void copy_via_host(void *dest_ptr, SyclQueue dest_queue,
12673
)
12774

12875

129-
cdef class _BufferData:
130-
"""
131-
Internal data struct populated from parsing
132-
`__sycl_usm_array_interface__` dictionary
133-
"""
134-
cdef DPCTLSyclUSMRef p
135-
cdef int writeable
136-
cdef object dt
137-
cdef Py_ssize_t itemsize
138-
cdef Py_ssize_t nbytes
139-
cdef SyclQueue queue
140-
141-
@staticmethod
142-
cdef _BufferData from_sycl_usm_ary_iface(dict ary_iface):
143-
cdef object ary_data_tuple = ary_iface.get('data', None)
144-
cdef object ary_typestr = ary_iface.get('typestr', None)
145-
cdef object ary_shape = ary_iface.get('shape', None)
146-
cdef object ary_strides = ary_iface.get('strides', None)
147-
cdef object ary_syclobj = ary_iface.get('syclobj', None)
148-
cdef Py_ssize_t ary_offset = ary_iface.get('offset', 0)
149-
cdef int ary_version = ary_iface.get('version', 0)
150-
cdef object dt
151-
cdef _BufferData buf
152-
cdef Py_ssize_t arr_data_ptr
153-
cdef SyclDevice dev
154-
cdef SyclContext ctx
155-
cdef DPCTLSyclQueueRef QRef = NULL
156-
157-
if ary_version != 1:
158-
raise _sycl_usm_ary_iface_error()
159-
if not ary_data_tuple or len(ary_data_tuple) != 2:
160-
raise _sycl_usm_ary_iface_error()
161-
if not ary_shape or len(ary_shape) != 1 or ary_shape[0] < 1:
162-
raise ValueError
163-
try:
164-
dt = np.dtype(ary_typestr)
165-
except TypeError:
166-
raise _sycl_usm_ary_iface_error()
167-
if (ary_strides and len(ary_strides) != 1
168-
and ary_strides[0] != dt.itemsize):
169-
raise ValueError("Must be contiguous")
170-
171-
if (not ary_syclobj or
172-
not isinstance(ary_syclobj,
173-
(dpctl.SyclQueue, dpctl.SyclContext))):
174-
raise _sycl_usm_ary_iface_error()
175-
176-
buf = _BufferData.__new__(_BufferData)
177-
arr_data_ptr = <Py_ssize_t>ary_data_tuple[0]
178-
buf.p = <DPCTLSyclUSMRef>(<void*>arr_data_ptr)
179-
buf.writeable = 1 if ary_data_tuple[1] else 0
180-
buf.itemsize = <Py_ssize_t>(dt.itemsize)
181-
buf.nbytes = (<Py_ssize_t>ary_shape[0]) * buf.itemsize
182-
183-
QRef = get_queue_ref_from_ptr_and_syclobj(buf.p, ary_syclobj)
184-
buf.queue = SyclQueue._create(QRef)
185-
186-
return buf
187-
188-
18976
def _to_memory(unsigned char [::1] b, str usm_kind):
19077
"""
19178
Constructs Memory of the same size as the argument
@@ -272,7 +159,7 @@ cdef class _Memory:
272159
elif hasattr(other, '__sycl_usm_array_interface__'):
273160
other_iface = other.__sycl_usm_array_interface__
274161
if isinstance(other_iface, dict):
275-
other_buf = _BufferData.from_sycl_usm_ary_iface(other_iface)
162+
other_buf = _USMBufferData.from_sycl_usm_ary_iface(other_iface)
276163
self.memory_ptr = other_buf.p
277164
self.nbytes = other_buf.nbytes
278165
self.queue = other_buf.queue
@@ -433,15 +320,15 @@ cdef class _Memory:
433320
cpdef copy_from_device(self, object sycl_usm_ary):
434321
"""Copy SYCL memory underlying the argument object into
435322
the memory of the instance"""
436-
cdef _BufferData src_buf
323+
cdef _USMBufferData src_buf
437324
cdef const char* kind
438325

439326
if not hasattr(sycl_usm_ary, '__sycl_usm_array_interface__'):
440327
raise ValueError("Object does not implement "
441328
"`__sycl_usm_array_interface__` protocol")
442329
sycl_usm_ary_iface = sycl_usm_ary.__sycl_usm_array_interface__
443330
if isinstance(sycl_usm_ary_iface, dict):
444-
src_buf = _BufferData.from_sycl_usm_ary_iface(sycl_usm_ary_iface)
331+
src_buf = _USMBufferData.from_sycl_usm_ary_iface(sycl_usm_ary_iface)
445332

446333
if (src_buf.nbytes > self.nbytes):
447334
raise ValueError("Source object is too large to "
Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
2+
3+
cdef bint _valid_usm_ptr_and_context(DPCTLSyclUSMRef ptr, SyclContext ctx):
    """ Tell whether `_Memory.get_pointer_type(ptr, ctx)` reports one of
        the recognized USM allocation kinds: b'shared', b'device', b'host'.
    """
    recognized_kinds = (b'shared', b'device', b'host')
    return _Memory.get_pointer_type(ptr, ctx) in recognized_kinds
6+
7+
8+
cdef DPCTLSyclQueueRef _queue_ref_copy_from_SyclQueue(
    DPCTLSyclUSMRef ptr, SyclQueue q):
    """ Return a copy of the QueueRef held by `q`, provided USM pointer
        `ptr` passes `_valid_usm_ptr_and_context` for the SYCL context of
        that queue. Return NULL when the check fails.
    """
    cdef SyclContext queue_ctx = q.get_sycl_context()
    # Guard clause: a pointer not recognized in the queue's context
    # yields no queue reference.
    if not _valid_usm_ptr_and_context(ptr, queue_ctx):
        return NULL
    return DPCTLQueue_Copy(q.get_queue_ref())
18+
19+
20+
cdef DPCTLSyclQueueRef _queue_ref_copy_from_USMRef_and_SyclContext(
    DPCTLSyclUSMRef ptr, SyclContext ctx):
    """ Look up the device associated with USM pointer `ptr` in context
        `ctx` (via `_Memory.get_pointer_device`) and create a queue for
        that context/device pair, from which the memory can be accessed.
    """
    cdef SyclDevice ptr_device = _Memory.get_pointer_device(ptr, ctx)
    cdef DPCTLSyclContextRef context_ref = ctx.get_context_ref()
    cdef DPCTLSyclDeviceRef device_ref = ptr_device.get_device_ref()
    cdef DPCTLSyclQueueRef new_queue_ref = DPCTLQueue_Create(
        context_ref, device_ref, NULL, 0)
    return new_queue_ref
30+
31+
32+
cdef DPCTLSyclQueueRef get_queue_ref_from_ptr_and_syclobj(
    DPCTLSyclUSMRef ptr, object syclobj):
    """ Constructs queue from pointer and syclobject from
        __sycl_usm_array_interface__

        `syclobj` is dispatched on as follows:
          - SyclQueue (exact type): its queue reference is copied;
          - SyclContext (exact type): a queue is created from the context
            and the device associated with `ptr`;
          - str: passed to the SyclQueue constructor, then handled as a
            queue;
          - PyCapsule named "SyclQueueRef" / "SyclContextRef": wrapped in
            SyclQueue / SyclContext and handled as above;
          - object with a `_get_capsule()` method: the returned capsule is
            handled like the capsule cases above.

        Returns a DPCTLSyclQueueRef, or NULL when `syclobj` is none of the
        recognized kinds. Queue-like inputs go through
        `_queue_ref_copy_from_SyclQueue`, which also returns NULL when the
        pointer is not valid in the queue's context.
    """
    cdef SyclContext ctx
    if type(syclobj) is SyclQueue:
        return _queue_ref_copy_from_SyclQueue(ptr, <SyclQueue> syclobj)
    elif type(syclobj) is SyclContext:
        ctx = <SyclContext>syclobj
        return _queue_ref_copy_from_USMRef_and_SyclContext(ptr, ctx)
    elif type(syclobj) is str:
        q = SyclQueue(syclobj)
        return _queue_ref_copy_from_SyclQueue(ptr, <SyclQueue> q)
    elif pycapsule.PyCapsule_IsValid(syclobj, "SyclQueueRef"):
        q = SyclQueue(syclobj)
        return _queue_ref_copy_from_SyclQueue(ptr, <SyclQueue> q)
    elif pycapsule.PyCapsule_IsValid(syclobj, "SyclContextRef"):
        ctx = <SyclContext>SyclContext(syclobj)
        return _queue_ref_copy_from_USMRef_and_SyclContext(ptr, ctx)
    elif hasattr(syclobj, '_get_capsule'):
        cap = syclobj._get_capsule()
        if pycapsule.PyCapsule_IsValid(cap, "SyclQueueRef"):
            q = SyclQueue(cap)
            return _queue_ref_copy_from_SyclQueue(ptr, <SyclQueue> q)
        # The capsule name must match exactly. It was previously
        # misspelled "SyclContexRef", which made this branch unreachable.
        elif pycapsule.PyCapsule_IsValid(cap, "SyclContextRef"):
            ctx = <SyclContext>SyclContext(cap)
            return _queue_ref_copy_from_USMRef_and_SyclContext(ptr, ctx)
        else:
            return NULL
    else:
        return NULL
64+
65+
66+
cdef object _pointers_from_shape_and_stride(
    int nd, object ary_shape, Py_ssize_t itemsize, Py_ssize_t ary_offset,
    object ary_strides):
    """
    Internal utility: for given array data about shape/layout/element
    compute left-most displacement when enumerating all elements of the array
    and the number of bytes of memory between the left-most and right-most
    displacements.

    Args:
        nd: number of array dimensions.
        ary_shape: sequence of per-dimension extents; the strided branch
            indexes it for `i in range(nd)`.
        itemsize: size of a single element in bytes.
        ary_offset: displacement of the first element. This function
            treats offset and strides as element counts: the computed span
            is multiplied by `itemsize` to obtain bytes.
        ary_strides: None, or a sequence of `nd` per-dimension strides.

    Returns: tuple(min_disp, nbytes)

    Raises:
        ValueError: if `nd` is negative, or if any shape element is not
            positive (checked in both the strided and non-strided cases).
    """
    if nd < 0:
        raise ValueError("Array dimensions can not be negative")
    if nd == 0:
        # A zero-dimensional array holds exactly one element.
        return (ary_offset, itemsize)
    if ary_strides is None:
        nelems = 1
        for si in ary_shape:
            sh_i = int(si)
            if sh_i <= 0:
                raise ValueError("Array shape elements need to be positive")
            nelems = nelems * sh_i
        return (ary_offset, nelems * itemsize)
    min_disp = ary_offset
    max_disp = ary_offset
    for i in range(nd):
        str_i = int(ary_strides[i])
        sh_i = int(ary_shape[i])
        # Validate extents here as well: a non-positive extent would make
        # `sh_i - 1` negative and shift the bounds in the wrong direction.
        if sh_i <= 0:
            raise ValueError("Array shape elements need to be positive")
        if str_i > 0:
            # Positive stride: the last index lies furthest to the right.
            max_disp += str_i * (sh_i - 1)
        else:
            # Non-positive stride: the last index lies furthest to the left.
            min_disp += str_i * (sh_i - 1)
    return (min_disp, (max_disp - min_disp + 1) * itemsize)
101+
102+
103+
cdef class _USMBufferData:
    """
    Internal data struct populated from parsing
    `__sycl_usm_array_interface__` dictionary
    """
    # USM pointer to the left-most element described by the interface
    cdef DPCTLSyclUSMRef p
    # 1 if the second item of the 'data' tuple is truthy, 0 otherwise
    cdef int writeable
    # declared for the element data type; not assigned by
    # from_sycl_usm_ary_iface
    cdef object dt
    # size of one element, in bytes
    cdef Py_ssize_t itemsize
    # number of bytes spanned by the array elements
    cdef Py_ssize_t nbytes
    # queue derived from the 'syclobj' entry and the USM pointer
    cdef SyclQueue queue

    @staticmethod
    cdef _USMBufferData from_sycl_usm_ary_iface(dict ary_iface):
        """
        Validate a `__sycl_usm_array_interface__` dictionary and build a
        `_USMBufferData` instance from it.

        Keys read: 'data' (required 2-tuple of pointer and writeable
        flag), 'typestr', 'shape' (required, sized iterable), 'strides'
        (optional, same length as shape), 'syclobj', 'offset' (default 0)
        and 'version' (must be 1).

        Raises:
            ValueError: if the dictionary is malformed, if the pointer is
                not consistent with 'syclobj', or if the data type is not
                an integer, floating or complex floating type.
        """
        cdef object ary_data_tuple = ary_iface.get('data', None)
        cdef object ary_typestr = ary_iface.get('typestr', None)
        cdef object ary_shape = ary_iface.get('shape', None)
        cdef object ary_strides = ary_iface.get('strides', None)
        cdef object ary_syclobj = ary_iface.get('syclobj', None)
        cdef Py_ssize_t ary_offset = ary_iface.get('offset', 0)
        cdef int ary_version = ary_iface.get('version', 0)
        cdef Py_ssize_t arr_data_ptr = 0
        cdef DPCTLSyclUSMRef memRef = NULL
        cdef Py_ssize_t itemsize = -1
        cdef int writeable = -1
        cdef int nd = -1
        cdef DPCTLSyclQueueRef QRef = NULL
        cdef object dt
        cdef _USMBufferData buf

        if ary_version != 1:
            raise ValueError(("__sycl_usm_array_interface__ is malformed:"
                              " dict('version': {}) is unexpected."
                              " The only recognized version is 1.").format(
                                  ary_version))
        if not ary_data_tuple or len(ary_data_tuple) != 2:
            raise ValueError("__sycl_usm_array_interface__ is malformed:"
                             " 'data' field is required, and must be a tuple"
                             " (usm_pointer, is_writeable_boolean).")
        arr_data_ptr = <Py_ssize_t>ary_data_tuple[0]
        writeable = 1 if ary_data_tuple[1] else 0
        # Check that memory and syclobj are consistent:
        # (USM pointer is bound to this sycl context)
        memRef = <DPCTLSyclUSMRef>arr_data_ptr
        QRef = get_queue_ref_from_ptr_and_syclobj(memRef, ary_syclobj)
        if (QRef is NULL):
            raise ValueError("__sycl_usm_array_interface__ is malformed:"
                             " 'data' field is not consistent with 'syclobj'"
                             " field, the pointer {} is not bound to"
                             " SyclContext derived from"
                             " dict('syclobj': {}).".format(
                                 hex(arr_data_ptr), ary_syclobj))
        # From here on this function holds QRef. Every failure in the
        # remaining validation is routed through the single handler below,
        # so the queue reference is released exactly once on any error
        # path (including exceptions raised by np.dtype, len() and
        # _pointers_from_shape_and_stride, which previously leaked it).
        try:
            # shape must be present
            if ary_shape is None or not (
                    isinstance(ary_shape, collections.abc.Sized) and
                    isinstance(ary_shape, collections.abc.Iterable)):
                raise ValueError("Shape entry is a required element of "
                                 "`__sycl_usm_array_interface__` dictionary")
            nd = len(ary_shape)
            try:
                dt = np.dtype(ary_typestr)
                if (dt.hasobject or not (
                        np.issubdtype(dt.type, np.integer) or
                        np.issubdtype(dt.type, np.inexact))):
                    raise TypeError("Only integer types, floating and complex "
                                    "floating types are supported.")
                itemsize = <Py_ssize_t> (dt.itemsize)
            except TypeError as e:
                raise ValueError(
                    "__sycl_usm_array_interface__ is malformed:"
                    " dict('typestr': {}) is unexpected. ".format(ary_typestr)
                ) from e

            if (ary_strides is None or (
                    isinstance(ary_strides, collections.abc.Sized) and
                    isinstance(ary_strides, collections.abc.Iterable) and
                    len(ary_strides) == nd)):
                min_disp, nbytes = _pointers_from_shape_and_stride(
                    nd, ary_shape, itemsize, ary_offset, ary_strides)
            else:
                raise ValueError("__sycl_usm_array_interface__ is malformed: "
                                 "'strides' must be a tuple or "
                                 "list of the same length as shape")

            buf = _USMBufferData.__new__(_USMBufferData)
            # Shift the base pointer to the left-most element; min_disp is
            # scaled by itemsize to obtain a byte displacement.
            buf.p = <DPCTLSyclUSMRef>(
                arr_data_ptr + (<Py_ssize_t>min_disp) * itemsize)
            buf.writeable = writeable
            buf.itemsize = itemsize
            buf.nbytes = <Py_ssize_t> nbytes
        except BaseException:
            DPCTLQueue_Delete(QRef)
            raise

        # NOTE(review): SyclQueue._create presumably assumes ownership of
        # QRef - confirm against its implementation.
        buf.queue = SyclQueue._create(QRef)

        return buf

0 commit comments

Comments
 (0)