@@ -73,7 +73,7 @@ import logging
73
73
74
74
75
75
# async_dec_ref is declared from the C++ helper header. With the new prototype,
# callers in this file take the returned DPCTLSyclEventRef as the host-task
# event and read the trailing `int *` as a status flag (non-zero is handled as
# a failed submission).
cdef extern from " _host_task_util.hpp" :
76
- int async_dec_ref(DPCTLSyclQueueRef, PyObject ** , size_t, DPCTLSyclEventRef * , size_t) nogil
76
+ DPCTLSyclEventRef async_dec_ref(DPCTLSyclQueueRef, PyObject ** , size_t, DPCTLSyclEventRef * , size_t, int * ) nogil
77
77
78
78
79
79
__all__ = [
@@ -716,6 +716,79 @@ cdef class SyclQueue(_SyclQueue):
716
716
"""
717
717
return < size_t> self ._queue_ref
718
718
719
+
720
cpdef SyclEvent _submit_keep_args_alive(
    self,
    object args,
    list dEvents
):
    """ SyclQueue._submit_keep_args_alive(args, dEvents)

    Keeps objects in `args` alive until tasks associated with events
    complete.

    Args:
        args(object): Python object to keep alive.
            Typically a tuple with arguments to offloaded tasks
        dEvents(List[dpctl.SyclEvent]): Gating events
            The list of events associated with tasks
            working on Python objects collected in `args`.
    Returns:
        dpctl.SyclEvent
           The event associated with the submission of host task.
    Raises:
        MemoryError: The array of dependent event references could not
            be allocated.
        TypeError: An element of `dEvents` is not a dpctl.SyclEvent.
        RuntimeError: Submission of the host task reported a failure.

    Increments reference count of `args` and schedules asynchronous
    ``host_task`` to decrement the count once dependent events are
    complete. If the submission reports a failure, the reference
    count is restored before the exception is raised.

    N.B.: The `host_task` attempts to acquire Python GIL, and it is
    known to be unsafe during interpreter shutdown sequence. It is
    thus strongly advised to ensure that all submitted `host_task`
    complete before the end of the Python script.
    """
    cdef size_t nDE = len(dEvents)
    cdef DPCTLSyclEventRef *depEvents = NULL
    cdef PyObject *args_raw = NULL
    cdef DPCTLSyclEventRef htERef = NULL
    cdef int status = -1

    # Create the array of dependent events if any
    if nDE > 0:
        depEvents = (
            <DPCTLSyclEventRef*>malloc(nDE * sizeof(DPCTLSyclEventRef))
        )
        if not depEvents:
            raise MemoryError()
        for idx, de in enumerate(dEvents):
            if not isinstance(de, SyclEvent):
                # release the partially filled array before bailing out
                free(depEvents)
                raise TypeError(
                    " A sequence of dpctl.SyclEvent is expected"
                )
            depEvents[idx] = (<SyclEvent>de).get_event_ref()

    # increment reference count of the arguments object; the scheduled
    # host task is expected to drop this reference
    Py_INCREF(args)
    args_raw = <PyObject *>args

    # schedule decrement
    htERef = async_dec_ref(
        self.get_queue_ref(),
        &args_raw, 1,
        depEvents, nDE, &status
    )

    # the array of dependencies is not needed once the call has returned
    free(depEvents)
    if status != 0:
        # Submission reported failure: undo the increment made above so
        # that `args` is not leaked (same handling as in `submit`).
        Py_DECREF(args)
        with nogil:
            DPCTLEvent_Wait(htERef)
        # the handle is not handed over to a SyclEvent on this path,
        # so release it here
        if htERef is not NULL:
            DPCTLEvent_Delete(htERef)
        raise RuntimeError(" Could not submit keep_args_alive host_task")

    return SyclEvent._create(htERef)
790
+
791
+
719
792
cpdef SyclEvent submit(
720
793
self ,
721
794
SyclKernel kernel,
@@ -728,13 +801,14 @@ cdef class SyclQueue(_SyclQueue):
728
801
cdef _arg_data_type * kargty = NULL
729
802
cdef DPCTLSyclEventRef * depEvents = NULL
730
803
cdef DPCTLSyclEventRef Eref = NULL
804
+ cdef DPCTLSyclEventRef htEref = NULL
731
805
cdef int ret = 0
732
806
cdef size_t gRange[3 ]
733
807
cdef size_t lRange[3 ]
734
808
cdef size_t nGS = len (gS)
735
809
cdef size_t nLS = len (lS) if lS is not None else 0
736
810
cdef size_t nDE = len (dEvents) if dEvents is not None else 0
737
- cdef PyObject ** arg_objects = NULL
811
+ cdef PyObject * args_raw = NULL
738
812
cdef ssize_t i = 0
739
813
740
814
# Allocate the arrays to be sent to DPCTLQueue_Submit
@@ -758,7 +832,15 @@ cdef class SyclQueue(_SyclQueue):
758
832
raise MemoryError ()
759
833
else :
760
834
for idx, de in enumerate (dEvents):
761
- depEvents[idx] = (< SyclEvent> de).get_event_ref()
835
+ if isinstance (de, SyclEvent):
836
+ depEvents[idx] = (< SyclEvent> de).get_event_ref()
837
+ else :
838
+ free(kargs)
839
+ free(kargty)
840
+ free(depEvents)
841
+ raise TypeError (
842
+ " A sequence of dpctl.SyclEvent is expected"
843
+ )
762
844
763
845
# populate the args and argstype arrays
764
846
ret = self ._populate_args(args, kargs, kargty)
@@ -836,22 +918,23 @@ cdef class SyclQueue(_SyclQueue):
836
918
raise SyclKernelSubmitError(
837
919
" Kernel submission to Sycl queue failed."
838
920
)
839
- # increment reference counts to each argument
840
- arg_objects = < PyObject ** > malloc(len (args) * sizeof(PyObject * ))
841
- for i in range (len (args)):
842
- arg_objects[i] = < PyObject * > (args[i])
843
- Py_INCREF(< object > arg_objects[i])
921
+ # increment reference counts to list of arguments
922
+ Py_INCREF(args)
844
923
845
924
# schedule decrement
846
- if async_dec_ref(self .get_queue_ref(), arg_objects, len (args), & Eref, 1 ):
925
+ args_raw = < PyObject * > args
926
+
927
+ ret = - 1
928
+ htERef = async_dec_ref(self .get_queue_ref(), & args_raw, 1 , & Eref, 1 , & ret)
929
+ if ret:
847
930
# async task submission failed, decrement ref counts and wait
848
- for i in range ( len ( args)):
849
- arg_objects[i] = < PyObject * > (args[i])
850
- Py_DECREF( < object > arg_objects[i] )
851
- with nogil: DPCTLEvent_Wait(Eref )
931
+ Py_DECREF( args)
932
+ with nogil:
933
+ DPCTLEvent_Wait(Eref )
934
+ DPCTLEvent_Wait(htERef )
852
935
853
- # free memory
854
- free(arg_objects )
936
+ # we are not returning host-task event at the moment
937
+ DPCTLEvent_Delete(htERef )
855
938
856
939
return SyclEvent._create(Eref)
857
940
0 commit comments