Skip to content

Commit aa9704c

Browse files
Adds dpctl.SyclQueue.memcpy_async
Also extends `dpctl.SyclQueue.memcpy` to accept arguments that expose the Python buffer protocol, so that `dpctl.SyclQueue.memcpy` and `dpctl.SyclQueue.memcpy_async` can be used to copy from/to a USM allocation or a host buffer.
1 parent 41540ab commit aa9704c

File tree

4 files changed

+164
-29
lines changed

4 files changed

+164
-29
lines changed

dpctl/_backend.pxd

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,13 @@ cdef extern from "syclinterface/dpctl_sycl_queue_interface.h":
403403
void *Dest,
404404
const void *Src,
405405
size_t Count)
406+
cdef DPCTLSyclEventRef DPCTLQueue_MemcpyWithEvents(
407+
const DPCTLSyclQueueRef Q,
408+
void *Dest,
409+
const void *Src,
410+
size_t Count,
411+
const DPCTLSyclEventRef *depEvents,
412+
size_t depEventsCount)
406413
cdef DPCTLSyclEventRef DPCTLQueue_Memset(
407414
const DPCTLSyclQueueRef Q,
408415
void *Dest,

dpctl/_sycl_queue.pxd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ cdef public api class SyclQueue (_SyclQueue) [
9696
cpdef void wait(self)
9797
cdef DPCTLSyclQueueRef get_queue_ref(self)
9898
cpdef memcpy(self, dest, src, size_t count)
99-
cpdef SyclEvent memcpy_async(self, dest, src, size_t count)
99+
cpdef SyclEvent memcpy_async(self, dest, src, size_t count, list dEvents=*)
100100
cpdef prefetch(self, ptr, size_t count=*)
101101
cpdef mem_advise(self, ptr, size_t count, int mem)
102102
cpdef SyclEvent submit_barrier(self, dependent_events=*)

dpctl/_sycl_queue.pyx

Lines changed: 83 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ from ._backend cimport ( # noqa: E211
4545
DPCTLQueue_IsInOrder,
4646
DPCTLQueue_MemAdvise,
4747
DPCTLQueue_Memcpy,
48+
DPCTLQueue_MemcpyWithEvents,
4849
DPCTLQueue_Prefetch,
4950
DPCTLQueue_SubmitBarrierForEvents,
5051
DPCTLQueue_SubmitNDRange,
@@ -65,6 +66,7 @@ import ctypes
6566
from .enum_types import backend_type
6667

6768
from cpython cimport pycapsule
69+
from cpython.buffer cimport PyObject_CheckBuffer
6870
from cpython.ref cimport Py_DECREF, Py_INCREF, PyObject
6971
from libc.stdlib cimport free, malloc
7072

@@ -173,6 +175,62 @@ cdef void _queue_capsule_deleter(object o) noexcept:
173175
DPCTLQueue_Delete(QRef)
174176

175177

178+
cdef bint _is_buffer(object o):
    """Tell whether `o` exposes the Python buffer protocol.

    Thin wrapper over the C-API check `PyObject_CheckBuffer`.
    """
    cdef bint has_buffer = PyObject_CheckBuffer(o)
    return has_buffer
180+
181+
182+
cdef DPCTLSyclEventRef _memcpy_impl(
    SyclQueue q,
    object dst,
    object src,
    size_t byte_count,
    DPCTLSyclEventRef *dep_events,
    size_t dep_events_count
) except *:
    """Submit a copy of `byte_count` bytes from `src` to `dst` on queue `q`.

    Each of `src` and `dst` may be a `dpctl.memory._Memory` instance or
    an object exposing the Python buffer protocol as a C-contiguous
    one-dimensional array of unsigned bytes (`dst` must be writable).

    Args:
        q: queue the copy operation is submitted to.
        dst: destination of the copy.
        src: source of the copy.
        byte_count: number of bytes to copy.
        dep_events: C array of events the copy must wait for, or NULL.
        dep_events_count: number of entries in `dep_events`.

    Returns:
        The event reference produced by the underlying
        `DPCTLQueue_Memcpy` / `DPCTLQueue_MemcpyWithEvents` call. It is
        returned as is, so callers must check it against NULL.

    Raises:
        TypeError: if `src` or `dst` is neither a `_Memory` object nor
            a buffer-protocol object.
        ValueError: if a host buffer holds fewer than `byte_count`
            bytes.

    The function is declared `except *` so that Python exceptions
    raised here are always propagated to the caller rather than
    depending on the default handling for a pointer return type.
    """
    cdef void *c_dst_ptr = NULL
    cdef void *c_src_ptr = NULL
    cdef DPCTLSyclEventRef ERef = NULL
    # NOTE(review): these views are locals, so the exported buffers are
    # released when this function returns; a caller that uses the
    # returned event asynchronously has to keep host buffers alive.
    cdef const unsigned char[::1] src_host_buf = None
    cdef unsigned char[::1] dst_host_buf = None

    if isinstance(src, _Memory):
        c_src_ptr = <void*>(<_Memory>src).memory_ptr
    elif _is_buffer(src):
        src_host_buf = src
        # Guard against reading past the end of the host buffer.
        if byte_count > <size_t>src_host_buf.shape[0]:
            raise ValueError(
                "Parameter `src` exposes a buffer smaller than "
                "the number of bytes requested to be copied"
            )
        c_src_ptr = <void *>&src_host_buf[0]
    else:
        raise TypeError(
            "Parameter `src` should have either type "
            "`dpctl.memory._Memory` or a type that "
            "supports Python buffer protocol"
        )

    if isinstance(dst, _Memory):
        c_dst_ptr = <void*>(<_Memory>dst).memory_ptr
    elif _is_buffer(dst):
        dst_host_buf = dst
        # Guard against writing past the end of the host buffer.
        if byte_count > <size_t>dst_host_buf.shape[0]:
            raise ValueError(
                "Parameter `dst` exposes a buffer smaller than "
                "the number of bytes requested to be copied"
            )
        c_dst_ptr = <void *>&dst_host_buf[0]
    else:
        raise TypeError(
            "Parameter `dst` should have either type "
            "`dpctl.memory._Memory` or a type that "
            "supports Python buffer protocol"
        )

    # Without dependencies the plain memcpy entry point is sufficient.
    if dep_events_count == 0 or dep_events is NULL:
        ERef = DPCTLQueue_Memcpy(q._queue_ref, c_dst_ptr, c_src_ptr, byte_count)
    else:
        ERef = DPCTLQueue_MemcpyWithEvents(
            q._queue_ref,
            c_dst_ptr,
            c_src_ptr,
            byte_count,
            dep_events,
            dep_events_count
        )
    return ERef
232+
233+
176234
cdef class _SyclQueue:
177235
""" Barebone data owner class used by SyclQueue.
178236
"""
@@ -938,44 +996,44 @@ cdef class SyclQueue(_SyclQueue):
938996
with nogil: DPCTLQueue_Wait(self._queue_ref)
939997

940998
cpdef memcpy(self, dest, src, size_t count):
    """Copy `count` bytes from `src` to `dest` and wait for completion.

    The arguments are forwarded to `_memcpy_impl` with no dependent
    events. The resulting event is waited on with the GIL released
    and then deleted.

    Raises:
        RuntimeError: if the copy operation returns a NULL event.
    """
    cdef DPCTLSyclEventRef ERef = _memcpy_impl(
        <SyclQueue>self, dest, src, count, NULL, 0
    )

    if ERef is NULL:
        raise RuntimeError(
            "SyclQueue.memcpy operation encountered an error"
        )
    with nogil:
        DPCTLEvent_Wait(ERef)
    DPCTLEvent_Delete(ERef)
9621009

963-
cpdef SyclEvent memcpy_async(self, dest, src, size_t count):
964-
cdef void *c_dest
965-
cdef void *c_src
1010+
cpdef SyclEvent memcpy_async(self, dest, src, size_t count, list dEvents=None):
1011+
"""Copy memory from `src` to `dst`"""
9661012
cdef DPCTLSyclEventRef ERef = NULL
1013+
cdef DPCTLSyclEventRef *depEvents = NULL
1014+
cdef size_t nDE = 0
9671015

968-
if isinstance(dest, _Memory):
969-
c_dest = <void*>(<_Memory>dest).memory_ptr
970-
else:
971-
raise TypeError("Parameter `dest` should have type _Memory.")
972-
973-
if isinstance(src, _Memory):
974-
c_src = <void*>(<_Memory>src).memory_ptr
1016+
if dEvents is None:
1017+
ERef = _memcpy_impl(<SyclQueue>self, dest, src, count, NULL, 0)
9751018
else:
976-
raise TypeError("Parameter `src` should have type _Memory.")
1019+
nDE = len(dEvents)
1020+
depEvents = (
1021+
<DPCTLSyclEventRef*>malloc(nDE*sizeof(DPCTLSyclEventRef))
1022+
)
1023+
if depEvents is NULL:
1024+
raise MemoryError()
1025+
else:
1026+
for idx, de in enumerate(dEvents):
1027+
if isinstance(de, SyclEvent):
1028+
depEvents[idx] = (<SyclEvent>de).get_event_ref()
1029+
else:
1030+
free(depEvents)
1031+
raise TypeError(
1032+
"A sequence of dpctl.SyclEvent is expected"
1033+
)
1034+
ERef = _memcpy_impl(self, dest, src, count, depEvents, nDE)
1035+
free(depEvents)
9771036

978-
ERef = DPCTLQueue_Memcpy(self._queue_ref, c_dest, c_src, count)
9791037
if (ERef is NULL):
9801038
raise RuntimeError(
9811039
"SyclQueue.memcpy operation encountered an error"

dpctl/tests/test_sycl_queue_memcpy.py

Lines changed: 73 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,77 @@ def test_memcpy_copy_usm_to_usm():
4444

4545
q.memcpy(mobj2, mobj1, 3)
4646

47-
assert mv2[:3], b"123"
47+
assert mv2[:3] == b"123"
48+
49+
50+
def test_memcpy_copy_host_to_usm():
    """Copy from a host memoryview into a USM allocation and compare."""
    try:
        queue = dpctl.SyclQueue()
    except dpctl.SyclQueueCreationError:
        pytest.skip("Default constructor for SyclQueue failed")
    usm_dst = _create_memory(queue)

    payload = bytearray(b"123456789")
    nbytes = len(payload)

    queue.memcpy(usm_dst, memoryview(payload), nbytes)

    # Inspect the USM allocation through its buffer interface.
    usm_view = memoryview(usm_dst)
    assert usm_view[:nbytes] == payload
65+
66+
67+
def test_memcpy_copy_usm_to_host():
    """Copy from a USM allocation into a host bytearray and compare."""
    try:
        q = dpctl.SyclQueue()
    except dpctl.SyclQueueCreationError:
        pytest.skip("Default constructor for SyclQueue failed")
    usm_obj = _create_memory(q)
    mv2 = memoryview(usm_obj)

    n = 9
    # Fill the first n bytes with "abcdefghi"; the loop variable is
    # named `offset` so that the builtin `id` is not shadowed.
    for offset in range(n):
        mv2[offset] = ord("a") + offset

    host_obj = bytearray(b" " * n)

    q.memcpy(host_obj, usm_obj, n)

    assert host_obj == b"abcdefghi"
84+
85+
86+
def test_memcpy_copy_host_to_host():
    """Copy between two host buffers through the queue and compare."""
    try:
        queue = dpctl.SyclQueue()
    except dpctl.SyclQueueCreationError:
        pytest.skip("Default constructor for SyclQueue failed")

    payload = b"abcdefghijklmnopqrstuvwxyz"
    nbytes = len(payload)
    received = bytearray(nbytes)

    queue.memcpy(received, payload, nbytes)

    assert received == payload
98+
99+
100+
def test_memcpy_async():
    """Submit two asynchronous host-to-host copies and check both."""
    try:
        queue = dpctl.SyclQueue()
    except dpctl.SyclQueueCreationError:
        pytest.skip("Default constructor for SyclQueue failed")

    payload = b"abcdefghijklmnopqrstuvwxyz"
    nbytes = len(payload)
    first_dst = bytearray(nbytes)
    second_dst = bytearray(nbytes)

    first_event = queue.memcpy_async(first_dst, payload, nbytes)
    second_event = queue.memcpy_async(second_dst, payload, nbytes)

    # The events are waited on in the reverse order of submission.
    second_event.wait()
    first_event.wait()
    assert first_dst == payload
    assert second_dst == payload
48118

49119

50120
def test_memcpy_type_error():
@@ -56,8 +126,8 @@ def test_memcpy_type_error():
56126

57127
with pytest.raises(TypeError) as cm:
58128
q.memcpy(None, mobj, 3)
59-
assert "`dest`" in str(cm.value)
129+
assert "_Memory" in str(cm.value)
60130

61131
with pytest.raises(TypeError) as cm:
62132
q.memcpy(mobj, None, 3)
63-
assert "`src`" in str(cm.value)
133+
assert "_Memory" in str(cm.value)

0 commit comments

Comments
 (0)