@@ -45,6 +45,7 @@ from ._backend cimport ( # noqa: E211
45
45
DPCTLQueue_IsInOrder,
46
46
DPCTLQueue_MemAdvise,
47
47
DPCTLQueue_Memcpy,
48
+ DPCTLQueue_MemcpyWithEvents,
48
49
DPCTLQueue_Prefetch,
49
50
DPCTLQueue_SubmitBarrierForEvents,
50
51
DPCTLQueue_SubmitNDRange,
@@ -65,6 +66,7 @@ import ctypes
65
66
from .enum_types import backend_type
66
67
67
68
from cpython cimport pycapsule
69
+ from cpython.buffer cimport PyObject_CheckBuffer
68
70
from cpython.ref cimport Py_DECREF, Py_INCREF, PyObject
69
71
from libc.stdlib cimport free, malloc
70
72
@@ -173,6 +175,62 @@ cdef void _queue_capsule_deleter(object o) noexcept:
173
175
DPCTLQueue_Delete(QRef)
174
176
175
177
178
+ cdef bint _is_buffer(object o):
179
+ return PyObject_CheckBuffer(o)
180
+
181
+
182
+ cdef DPCTLSyclEventRef _memcpy_impl(
183
+ SyclQueue q,
184
+ object dst,
185
+ object src,
186
+ size_t byte_count,
187
+ DPCTLSyclEventRef * dep_events,
188
+ size_t dep_events_count
189
+ ):
190
+ cdef void * c_dst_ptr = NULL
191
+ cdef void * c_src_ptr = NULL
192
+ cdef DPCTLSyclEventRef ERef = NULL
193
+ cdef const unsigned char [::1 ] src_host_buf = None
194
+ cdef unsigned char [::1 ] dst_host_buf = None
195
+
196
+ if isinstance (src, _Memory):
197
+ c_src_ptr = < void * > (< _Memory> src).memory_ptr
198
+ elif _is_buffer(src):
199
+ src_host_buf = src
200
+ c_src_ptr = < void * > & src_host_buf[0 ]
201
+ else :
202
+ raise TypeError (
203
+ " Parameter `src` should have either type "
204
+ " `dpctl.memory._Memory` or a type that "
205
+ " supports Python buffer protocol"
206
+ )
207
+
208
+ if isinstance (dst, _Memory):
209
+ c_dst_ptr = < void * > (< _Memory> dst).memory_ptr
210
+ elif _is_buffer(dst):
211
+ dst_host_buf = dst
212
+ c_dst_ptr = < void * > & dst_host_buf[0 ]
213
+ else :
214
+ raise TypeError (
215
+ " Parameter `dst` should have either type "
216
+ " `dpctl.memory._Memory` or a type that "
217
+ " supports Python buffer protocol"
218
+ )
219
+
220
+ if dep_events_count == 0 or dep_events is NULL :
221
+ ERef = DPCTLQueue_Memcpy(q._queue_ref, c_dst_ptr, c_src_ptr, byte_count)
222
+ else :
223
+ ERef = DPCTLQueue_MemcpyWithEvents(
224
+ q._queue_ref,
225
+ c_dst_ptr,
226
+ c_src_ptr,
227
+ byte_count,
228
+ dep_events,
229
+ dep_events_count
230
+ )
231
+ return ERef
232
+
233
+
176
234
cdef class _SyclQueue:
177
235
""" Barebone data owner class used by SyclQueue.
178
236
"""
@@ -938,44 +996,44 @@ cdef class SyclQueue(_SyclQueue):
938
996
with nogil: DPCTLQueue_Wait(self ._queue_ref)
939
997
940
998
cpdef memcpy(self , dest, src, size_t count):
941
- cdef void * c_dest
942
- cdef void * c_src
999
+ """ Copy memory from `src` to `dst`"""
943
1000
cdef DPCTLSyclEventRef ERef = NULL
944
1001
945
- if isinstance (dest, _Memory):
946
- c_dest = < void * > (< _Memory> dest).memory_ptr
947
- else :
948
- raise TypeError (" Parameter `dest` should have type _Memory." )
949
-
950
- if isinstance (src, _Memory):
951
- c_src = < void * > (< _Memory> src).memory_ptr
952
- else :
953
- raise TypeError (" Parameter `src` should have type _Memory." )
954
-
955
- ERef = DPCTLQueue_Memcpy(self ._queue_ref, c_dest, c_src, count)
1002
+ ERef = _memcpy_impl(< SyclQueue> self , dest, src, count, NULL , 0 )
956
1003
if (ERef is NULL ):
957
1004
raise RuntimeError (
958
1005
" SyclQueue.memcpy operation encountered an error"
959
1006
)
960
1007
with nogil: DPCTLEvent_Wait(ERef)
961
1008
DPCTLEvent_Delete(ERef)
962
1009
963
- cpdef SyclEvent memcpy_async(self , dest, src, size_t count):
964
- cdef void * c_dest
965
- cdef void * c_src
1010
+ cpdef SyclEvent memcpy_async(self , dest, src, size_t count, list dEvents = None ):
1011
+ """ Copy memory from `src` to `dst`"""
966
1012
cdef DPCTLSyclEventRef ERef = NULL
1013
+ cdef DPCTLSyclEventRef * depEvents = NULL
1014
+ cdef size_t nDE = 0
967
1015
968
- if isinstance (dest, _Memory):
969
- c_dest = < void * > (< _Memory> dest).memory_ptr
970
- else :
971
- raise TypeError (" Parameter `dest` should have type _Memory." )
972
-
973
- if isinstance (src, _Memory):
974
- c_src = < void * > (< _Memory> src).memory_ptr
1016
+ if dEvents is None :
1017
+ ERef = _memcpy_impl(< SyclQueue> self , dest, src, count, NULL , 0 )
975
1018
else :
976
- raise TypeError (" Parameter `src` should have type _Memory." )
1019
+ nDE = len (dEvents)
1020
+ depEvents = (
1021
+ < DPCTLSyclEventRef* > malloc(nDE* sizeof(DPCTLSyclEventRef))
1022
+ )
1023
+ if depEvents is NULL :
1024
+ raise MemoryError ()
1025
+ else :
1026
+ for idx, de in enumerate (dEvents):
1027
+ if isinstance (de, SyclEvent):
1028
+ depEvents[idx] = (< SyclEvent> de).get_event_ref()
1029
+ else :
1030
+ free(depEvents)
1031
+ raise TypeError (
1032
+ " A sequence of dpctl.SyclEvent is expected"
1033
+ )
1034
+ ERef = _memcpy_impl(self , dest, src, count, depEvents, nDE)
1035
+ free(depEvents)
977
1036
978
- ERef = DPCTLQueue_Memcpy(self ._queue_ref, c_dest, c_src, count)
979
1037
if (ERef is NULL ):
980
1038
raise RuntimeError (
981
1039
" SyclQueue.memcpy operation encountered an error"
0 commit comments