Skip to content

Commit f77f7a4

Browse files
Support stream keyword in usm_ndarray.to_device per array API
1 parent ebd1faf commit f77f7a4

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

dpctl/tensor/_usmarray.pyx

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -816,7 +816,7 @@ cdef class usm_ndarray:
816816
return _take_multi_index(res, adv_ind, adv_ind_start_p)
817817

818818

819-
def to_device(self, target):
819+
def to_device(self, target, stream=None):
820820
""" to_device(target_device)
821821
822822
Transfers this array to specified target device.
@@ -856,6 +856,14 @@ cdef class usm_ndarray:
856856
cdef c_dpctl.DPCTLSyclQueueRef QRef = NULL
857857
cdef c_dpmem._Memory arr_buf
858858
d = Device.create_device(target)
859+
860+
if (stream is None or type(stream) is not dpctl.SyclQueue or
861+
stream == self.sycl_queue):
862+
pass
863+
else:
864+
ev = self.sycl_queue.submit_barrier()
865+
stream.submit_barrier(dependent_events=[ev])
866+
859867
if (d.sycl_context == self.sycl_context):
860868
arr_buf = <c_dpmem._Memory> self.usm_data
861869
QRef = (<c_dpctl.SyclQueue> d.sycl_queue).get_queue_ref()

0 commit comments

Comments
 (0)