@@ -72,7 +72,7 @@ import logging
72
72
73
73
74
74
cdef extern from " _host_task_util.hpp" :
75
- int async_dec_ref(DPCTLSyclQueueRef, PyObject ** , size_t, DPCTLSyclEventRef * , size_t) nogil
75
+ DPCTLSyclEventRef async_dec_ref(DPCTLSyclQueueRef, PyObject ** , size_t, DPCTLSyclEventRef * , size_t, int * ) nogil
76
76
77
77
78
78
__all__ = [
@@ -703,6 +703,79 @@ cdef class SyclQueue(_SyclQueue):
703
703
"""
704
704
return < size_t> self ._queue_ref
705
705
706
+
707
+ cpdef SyclEvent _submit_keep_args_alive(
708
+ self ,
709
+ object args,
710
+ list dEvents
711
+ ):
712
+ """ SyclQueue._submit_keep_args_alive(args, events)
713
+
714
+ Keeps objects in `args` alive until tasks associated with events
715
+ complete.
716
+
717
+ Args:
718
+ args(object): Python object to keep alive.
719
+ Typically a tuple with arguments to offloaded tasks
720
+ events(Tuple[dpctl.SyclEvent]): Gating events
721
+ The list or tuple of events associated with tasks
722
+ working on Python objects collected in `args`.
723
+ Returns:
724
+ dpctl.SyclEvent
725
+ The event associated with the submission of host task.
726
+
727
+ Increments reference count of `args` and schedules asynchronous
728
+ ``host_task`` to decrement the count once dependent events are
729
+ complete.
730
+
731
+ N.B.: The `host_task` attempts to acquire Python GIL, and it is
732
+ known to be unsafe during interpreter shutdown sequence. It is
733
+ thus strongly advised to ensure that all submitted `host_task`
734
+ complete before the end of the Python script.
735
+ """
736
+ cdef size_t nDE = len (dEvents)
737
+ cdef DPCTLSyclEventRef * depEvents = NULL
738
+ cdef PyObject * args_raw = NULL
739
+ cdef DPCTLSyclEventRef htERef = NULL
740
+ cdef int status = - 1
741
+
742
+ # Create the array of dependent events if any
743
+ if nDE > 0 :
744
+ depEvents = (
745
+ < DPCTLSyclEventRef* > malloc(nDE* sizeof(DPCTLSyclEventRef))
746
+ )
747
+ if not depEvents:
748
+ raise MemoryError ()
749
+ else :
750
+ for idx, de in enumerate (dEvents):
751
+ if isinstance (de, SyclEvent):
752
+ depEvents[idx] = (< SyclEvent> de).get_event_ref()
753
+ else :
754
+ free(depEvents)
755
+ raise TypeError (
756
+ " A sequence of dpctl.SyclEvent is expected"
757
+ )
758
+
759
+ # increment reference counts to list of arguments
760
+ Py_INCREF(args)
761
+
762
+ # schedule decrement
763
+ args_raw = < PyObject * > args
764
+
765
+ htERef = async_dec_ref(
766
+ self .get_queue_ref(),
767
+ & args_raw, 1 ,
768
+ depEvents, nDE, & status
769
+ )
770
+
771
+ free(depEvents)
772
+ if (status != 0 ):
773
+ with nogil: DPCTLEvent_Wait(htERef)
774
+ raise RuntimeError (" Could not submit keep_args_alive host_task" )
775
+
776
+ return SyclEvent._create(htERef)
777
+
778
+
706
779
cpdef SyclEvent submit(
707
780
self ,
708
781
SyclKernel kernel,
@@ -715,13 +788,14 @@ cdef class SyclQueue(_SyclQueue):
715
788
cdef _arg_data_type * kargty = NULL
716
789
cdef DPCTLSyclEventRef * depEvents = NULL
717
790
cdef DPCTLSyclEventRef Eref = NULL
791
+ cdef DPCTLSyclEventRef htEref = NULL
718
792
cdef int ret = 0
719
793
cdef size_t gRange[3 ]
720
794
cdef size_t lRange[3 ]
721
795
cdef size_t nGS = len (gS)
722
796
cdef size_t nLS = len (lS) if lS is not None else 0
723
797
cdef size_t nDE = len (dEvents) if dEvents is not None else 0
724
- cdef PyObject ** arg_objects = NULL
798
+ cdef PyObject * args_raw = NULL
725
799
cdef ssize_t i = 0
726
800
727
801
# Allocate the arrays to be sent to DPCTLQueue_Submit
@@ -745,7 +819,15 @@ cdef class SyclQueue(_SyclQueue):
745
819
raise MemoryError ()
746
820
else :
747
821
for idx, de in enumerate (dEvents):
748
- depEvents[idx] = (< SyclEvent> de).get_event_ref()
822
+ if isinstance (de, SyclEvent):
823
+ depEvents[idx] = (< SyclEvent> de).get_event_ref()
824
+ else :
825
+ free(kargs)
826
+ free(kargty)
827
+ free(depEvents)
828
+ raise TypeError (
829
+ " A sequence of dpctl.SyclEvent is expected"
830
+ )
749
831
750
832
# populate the args and argstype arrays
751
833
ret = self ._populate_args(args, kargs, kargty)
@@ -823,22 +905,23 @@ cdef class SyclQueue(_SyclQueue):
823
905
raise SyclKernelSubmitError(
824
906
" Kernel submission to Sycl queue failed."
825
907
)
826
- # increment reference counts to each argument
827
- arg_objects = < PyObject ** > malloc(len (args) * sizeof(PyObject * ))
828
- for i in range (len (args)):
829
- arg_objects[i] = < PyObject * > (args[i])
830
- Py_INCREF(< object > arg_objects[i])
908
+ # increment reference counts to list of arguments
909
+ Py_INCREF(args)
831
910
832
911
# schedule decrement
833
- if async_dec_ref(self .get_queue_ref(), arg_objects, len (args), & Eref, 1 ):
912
+ args_raw = < PyObject * > args
913
+
914
+ ret = - 1
915
+ htERef = async_dec_ref(self .get_queue_ref(), & args_raw, 1 , & Eref, 1 , & ret)
916
+ if ret:
834
917
# async task submission failed, decrement ref counts and wait
835
- for i in range ( len ( args)):
836
- arg_objects[i] = < PyObject * > (args[i])
837
- Py_DECREF( < object > arg_objects[i] )
838
- with nogil: DPCTLEvent_Wait(Eref )
918
+ Py_DECREF( args)
919
+ with nogil:
920
+ DPCTLEvent_Wait(Eref )
921
+ DPCTLEvent_Wait(htERef )
839
922
840
- # free memory
841
- free(arg_objects )
923
+ # we are not returning host-task event at the moment
924
+ DPCTLEvent_Delete(htERef )
842
925
843
926
return SyclEvent._create(Eref)
844
927
0 commit comments