@@ -45,6 +45,7 @@ from ._backend cimport ( # noqa: E211
45
45
DPCTLQueue_IsInOrder,
46
46
DPCTLQueue_MemAdvise,
47
47
DPCTLQueue_Memcpy,
48
+ DPCTLQueue_MemcpyWithEvents,
48
49
DPCTLQueue_Prefetch,
49
50
DPCTLQueue_SubmitBarrierForEvents,
50
51
DPCTLQueue_SubmitNDRange,
@@ -64,6 +65,7 @@ import ctypes
64
65
from .enum_types import backend_type
65
66
66
67
from cpython cimport pycapsule
68
+ from cpython.buffer cimport PyObject_CheckBuffer
67
69
from cpython.ref cimport Py_DECREF, Py_INCREF, PyObject
68
70
from libc.stdlib cimport free, malloc
69
71
@@ -160,6 +162,62 @@ cdef void _queue_capsule_deleter(object o) noexcept:
160
162
DPCTLQueue_Delete(QRef)
161
163
162
164
165
+ cdef bint _is_buffer(object o):
166
+ return PyObject_CheckBuffer(o)
167
+
168
+
169
+ cdef DPCTLSyclEventRef _memcpy_impl(
170
+ SyclQueue q,
171
+ object dst,
172
+ object src,
173
+ size_t byte_count,
174
+ DPCTLSyclEventRef * dep_events,
175
+ size_t dep_events_count
176
+ ):
177
+ cdef void * c_dst_ptr = NULL
178
+ cdef void * c_src_ptr = NULL
179
+ cdef DPCTLSyclEventRef ERef = NULL
180
+ cdef const unsigned char [::1 ] src_host_buf = None
181
+ cdef unsigned char [::1 ] dst_host_buf = None
182
+
183
+ if isinstance (src, _Memory):
184
+ c_src_ptr = < void * > (< _Memory> src).memory_ptr
185
+ elif _is_buffer(src):
186
+ src_host_buf = src
187
+ c_src_ptr = < void * > & src_host_buf[0 ]
188
+ else :
189
+ raise TypeError (
190
+ " Parameter `src` should have either type "
191
+ " `dpctl.memory._Memory` or a type that "
192
+ " supports Python buffer protocol"
193
+ )
194
+
195
+ if isinstance (dst, _Memory):
196
+ c_dst_ptr = < void * > (< _Memory> dst).memory_ptr
197
+ elif _is_buffer(dst):
198
+ dst_host_buf = dst
199
+ c_dst_ptr = < void * > & dst_host_buf[0 ]
200
+ else :
201
+ raise TypeError (
202
+ " Parameter `dst` should have either type "
203
+ " `dpctl.memory._Memory` or a type that "
204
+ " supports Python buffer protocol"
205
+ )
206
+
207
+ if dep_events_count == 0 or dep_events is NULL :
208
+ ERef = DPCTLQueue_Memcpy(q._queue_ref, c_dst_ptr, c_src_ptr, byte_count)
209
+ else :
210
+ ERef = DPCTLQueue_MemcpyWithEvents(
211
+ q._queue_ref,
212
+ c_dst_ptr,
213
+ c_src_ptr,
214
+ byte_count,
215
+ dep_events,
216
+ dep_events_count
217
+ )
218
+ return ERef
219
+
220
+
163
221
cdef class _SyclQueue:
164
222
""" Barebone data owner class used by SyclQueue.
165
223
"""
@@ -925,44 +983,44 @@ cdef class SyclQueue(_SyclQueue):
925
983
with nogil: DPCTLQueue_Wait(self ._queue_ref)
926
984
927
985
cpdef memcpy(self , dest, src, size_t count):
928
- cdef void * c_dest
929
- cdef void * c_src
986
+ """ Copy memory from `src` to `dst`"""
930
987
cdef DPCTLSyclEventRef ERef = NULL
931
988
932
- if isinstance (dest, _Memory):
933
- c_dest = < void * > (< _Memory> dest).memory_ptr
934
- else :
935
- raise TypeError (" Parameter `dest` should have type _Memory." )
936
-
937
- if isinstance (src, _Memory):
938
- c_src = < void * > (< _Memory> src).memory_ptr
939
- else :
940
- raise TypeError (" Parameter `src` should have type _Memory." )
941
-
942
- ERef = DPCTLQueue_Memcpy(self ._queue_ref, c_dest, c_src, count)
989
+ ERef = _memcpy_impl(< SyclQueue> self , dest, src, count, NULL , 0 )
943
990
if (ERef is NULL ):
944
991
raise RuntimeError (
945
992
" SyclQueue.memcpy operation encountered an error"
946
993
)
947
994
with nogil: DPCTLEvent_Wait(ERef)
948
995
DPCTLEvent_Delete(ERef)
949
996
950
- cpdef SyclEvent memcpy_async(self , dest, src, size_t count):
951
- cdef void * c_dest
952
- cdef void * c_src
997
+ cpdef SyclEvent memcpy_async(self , dest, src, size_t count, list dEvents = None ):
998
+ """ Copy memory from `src` to `dst`"""
953
999
cdef DPCTLSyclEventRef ERef = NULL
1000
+ cdef DPCTLSyclEventRef * depEvents = NULL
1001
+ cdef size_t nDE = 0
954
1002
955
- if isinstance (dest, _Memory):
956
- c_dest = < void * > (< _Memory> dest).memory_ptr
957
- else :
958
- raise TypeError (" Parameter `dest` should have type _Memory." )
959
-
960
- if isinstance (src, _Memory):
961
- c_src = < void * > (< _Memory> src).memory_ptr
1003
+ if dEvents is None :
1004
+ ERef = _memcpy_impl(< SyclQueue> self , dest, src, count, NULL , 0 )
962
1005
else :
963
- raise TypeError (" Parameter `src` should have type _Memory." )
1006
+ nDE = len (dEvents)
1007
+ depEvents = (
1008
+ < DPCTLSyclEventRef* > malloc(nDE* sizeof(DPCTLSyclEventRef))
1009
+ )
1010
+ if depEvents is NULL :
1011
+ raise MemoryError ()
1012
+ else :
1013
+ for idx, de in enumerate (dEvents):
1014
+ if isinstance (de, SyclEvent):
1015
+ depEvents[idx] = (< SyclEvent> de).get_event_ref()
1016
+ else :
1017
+ free(depEvents)
1018
+ raise TypeError (
1019
+ " A sequence of dpctl.SyclEvent is expected"
1020
+ )
1021
+ ERef = _memcpy_impl(self , dest, src, count, depEvents, nDE)
1022
+ free(depEvents)
964
1023
965
- ERef = DPCTLQueue_Memcpy(self ._queue_ref, c_dest, c_src, count)
966
1024
if (ERef is NULL ):
967
1025
raise RuntimeError (
968
1026
" SyclQueue.memcpy operation encountered an error"
0 commit comments