Skip to content

Commit be5fb99

Browse files
Added SyclQueue._submit_keep_args_alive method
Usage: q = dpctl.SyclQueue() ... e = q.submit(krn, args, ranges) ht_e = q._submit_keep_args_alive(args, [e]) .... ht_e.wait()
1 parent 1d57614 commit be5fb99

File tree

3 files changed

+132
-33
lines changed

3 files changed

+132
-33
lines changed

dpctl/_host_task_util.hpp

Lines changed: 29 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
//
33
// Data Parallel Control (dpctl)
44
//
5-
// Copyright 2020-2022 Intel Corporation
5+
// Copyright 2020-2023 Intel Corporation
66
//
77
// Licensed under the Apache License, Version 2.0 (the "License");
88
// you may not use this file except in compliance with the License.
@@ -31,28 +31,27 @@
3131

3232
#include "Python.h"
3333
#include "syclinterface/dpctl_data_types.h"
34+
#include "syclinterface/dpctl_sycl_type_casters.hpp"
3435
#include <CL/sycl.hpp>
3536

36-
int async_dec_ref(DPCTLSyclQueueRef QRef,
37-
PyObject **obj_array,
38-
size_t obj_array_size,
39-
DPCTLSyclEventRef *ERefs,
40-
size_t nERefs)
37+
DPCTLSyclEventRef async_dec_ref(DPCTLSyclQueueRef QRef,
38+
PyObject **obj_array,
39+
size_t obj_array_size,
40+
DPCTLSyclEventRef *depERefs,
41+
size_t nDepERefs,
42+
int *status)
4143
{
44+
using dpctl::syclinterface::unwrap;
45+
using dpctl::syclinterface::wrap;
4246

43-
sycl::queue *q = reinterpret_cast<sycl::queue *>(QRef);
47+
sycl::queue *q = unwrap<sycl::queue>(QRef);
4448

45-
std::vector<PyObject *> obj_vec;
46-
obj_vec.reserve(obj_array_size);
47-
for (size_t obj_id = 0; obj_id < obj_array_size; ++obj_id) {
48-
obj_vec.push_back(obj_array[obj_id]);
49-
}
49+
std::vector<PyObject *> obj_vec(obj_array, obj_array + obj_array_size);
5050

5151
try {
52-
q->submit([&](sycl::handler &cgh) {
53-
for (size_t ev_id = 0; ev_id < nERefs; ++ev_id) {
54-
cgh.depends_on(
55-
*(reinterpret_cast<sycl::event *>(ERefs[ev_id])));
52+
sycl::event ht_ev = q->submit([&](sycl::handler &cgh) {
53+
for (size_t ev_id = 0; ev_id < nDepERefs; ++ev_id) {
54+
cgh.depends_on(*(unwrap<sycl::event>(depERefs[ev_id])));
5655
}
5756
cgh.host_task([obj_array_size, obj_vec]() {
5857
// if the main thread has not finilized the interpreter yet
@@ -66,9 +65,21 @@ int async_dec_ref(DPCTLSyclQueueRef QRef,
6665
}
6766
});
6867
});
68+
69+
constexpr int result_ok = 0;
70+
71+
*status = result_ok;
72+
auto e_ptr = new sycl::event(ht_ev);
73+
return wrap<sycl::event>(e_ptr);
6974
} catch (const std::exception &e) {
70-
return 1;
75+
constexpr int result_std_exception = 1;
76+
77+
*status = result_std_exception;
78+
return nullptr;
7179
}
7280

73-
return 0;
81+
constexpr int result_other_abnormal = 2;
82+
83+
*status = result_other_abnormal;
84+
return nullptr;
7485
}

dpctl/_sycl_queue.pxd

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,11 @@ cdef public api class SyclQueue (_SyclQueue) [
7070
cpdef SyclContext get_sycl_context(self)
7171
cpdef SyclDevice get_sycl_device(self)
7272
cdef DPCTLSyclQueueRef get_queue_ref(self)
73+
cpdef SyclEvent _submit_keep_args_alive(
74+
self,
75+
object args,
76+
list dEvents
77+
)
7378
cpdef SyclEvent submit(
7479
self,
7580
SyclKernel kernel,

dpctl/_sycl_queue.pyx

Lines changed: 98 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ import logging
7272

7373

7474
cdef extern from "_host_task_util.hpp":
75-
int async_dec_ref(DPCTLSyclQueueRef, PyObject **, size_t, DPCTLSyclEventRef *, size_t) nogil
75+
DPCTLSyclEventRef async_dec_ref(DPCTLSyclQueueRef, PyObject **, size_t, DPCTLSyclEventRef *, size_t, int *) nogil
7676

7777

7878
__all__ = [
@@ -703,6 +703,79 @@ cdef class SyclQueue(_SyclQueue):
703703
"""
704704
return <size_t>self._queue_ref
705705

706+
707+
cpdef SyclEvent _submit_keep_args_alive(
708+
self,
709+
object args,
710+
list dEvents
711+
):
712+
""" SyclQueue._submit_keep_args_alive(args, events)
713+
714+
Keeps objects in `args` alive until tasks associated with events
715+
complete.
716+
717+
Args:
718+
args(object): Python object to keep alive.
719+
Typically a tuple with arguments to offloaded tasks
720+
events(Tuple[dpctl.SyclEvent]): Gating events
721+
The list or tuple of events associated with tasks
722+
working on Python objects collected in `args`.
723+
Returns:
724+
dpctl.SyclEvent
725+
The event associated with the submission of host task.
726+
727+
Increments reference count of `args` and schedules asynchronous
728+
``host_task`` to decrement the count once dependent events are
729+
complete.
730+
731+
N.B.: The `host_task` attempts to acquire Python GIL, and it is
732+
known to be unsafe during interpreter shudown sequence. It is
733+
thus strongly advised to ensure that all submitted `host_task`
734+
complete before the end of the Python script.
735+
"""
736+
cdef size_t nDE = len(dEvents)
737+
cdef DPCTLSyclEventRef *depEvents = NULL
738+
cdef PyObject *args_raw = NULL
739+
cdef DPCTLSyclEventRef htERef = NULL
740+
cdef int status = -1
741+
742+
# Create the array of dependent events if any
743+
if nDE > 0:
744+
depEvents = (
745+
<DPCTLSyclEventRef*>malloc(nDE*sizeof(DPCTLSyclEventRef))
746+
)
747+
if not depEvents:
748+
raise MemoryError()
749+
else:
750+
for idx, de in enumerate(dEvents):
751+
if isinstance(de, SyclEvent):
752+
depEvents[idx] = (<SyclEvent>de).get_event_ref()
753+
else:
754+
free(depEvents)
755+
raise TypeError(
756+
"A sequence of dpctl.SyclEvent is expected"
757+
)
758+
759+
# increment reference counts to list of arguments
760+
Py_INCREF(args)
761+
762+
# schedule decrement
763+
args_raw = <PyObject *>args
764+
765+
htERef = async_dec_ref(
766+
self.get_queue_ref(),
767+
&args_raw, 1,
768+
depEvents, nDE, &status
769+
)
770+
771+
free(depEvents)
772+
if (status != 0):
773+
with nogil: DPCTLEvent_Wait(htERef)
774+
raise RuntimeError("Could not submit keep_args_alive host_task")
775+
776+
return SyclEvent._create(htERef)
777+
778+
706779
cpdef SyclEvent submit(
707780
self,
708781
SyclKernel kernel,
@@ -715,13 +788,14 @@ cdef class SyclQueue(_SyclQueue):
715788
cdef _arg_data_type *kargty = NULL
716789
cdef DPCTLSyclEventRef *depEvents = NULL
717790
cdef DPCTLSyclEventRef Eref = NULL
791+
cdef DPCTLSyclEventRef htEref = NULL
718792
cdef int ret = 0
719793
cdef size_t gRange[3]
720794
cdef size_t lRange[3]
721795
cdef size_t nGS = len(gS)
722796
cdef size_t nLS = len(lS) if lS is not None else 0
723797
cdef size_t nDE = len(dEvents) if dEvents is not None else 0
724-
cdef PyObject **arg_objects = NULL
798+
cdef PyObject *args_raw = NULL
725799
cdef ssize_t i = 0
726800

727801
# Allocate the arrays to be sent to DPCTLQueue_Submit
@@ -745,7 +819,15 @@ cdef class SyclQueue(_SyclQueue):
745819
raise MemoryError()
746820
else:
747821
for idx, de in enumerate(dEvents):
748-
depEvents[idx] = (<SyclEvent>de).get_event_ref()
822+
if isinstance(de, SyclEvent):
823+
depEvents[idx] = (<SyclEvent>de).get_event_ref()
824+
else:
825+
free(kargs)
826+
free(kargty)
827+
free(depEvents)
828+
raise TypeError(
829+
"A sequence of dpctl.SyclEvent is expected"
830+
)
749831

750832
# populate the args and argstype arrays
751833
ret = self._populate_args(args, kargs, kargty)
@@ -823,22 +905,23 @@ cdef class SyclQueue(_SyclQueue):
823905
raise SyclKernelSubmitError(
824906
"Kernel submission to Sycl queue failed."
825907
)
826-
# increment reference counts to each argument
827-
arg_objects = <PyObject **>malloc(len(args) * sizeof(PyObject *))
828-
for i in range(len(args)):
829-
arg_objects[i] = <PyObject *>(args[i])
830-
Py_INCREF(<object> arg_objects[i])
908+
# increment reference counts to list of arguments
909+
Py_INCREF(args)
831910

832911
# schedule decrement
833-
if async_dec_ref(self.get_queue_ref(), arg_objects, len(args), &Eref, 1):
912+
args_raw = <PyObject *>args
913+
914+
ret = -1
915+
htERef = async_dec_ref(self.get_queue_ref(), &args_raw, 1, &Eref, 1, &ret)
916+
if ret:
834917
# async task submission failed, decrement ref counts and wait
835-
for i in range(len(args)):
836-
arg_objects[i] = <PyObject *>(args[i])
837-
Py_DECREF(<object> arg_objects[i])
838-
with nogil: DPCTLEvent_Wait(Eref)
918+
Py_DECREF(args)
919+
with nogil:
920+
DPCTLEvent_Wait(Eref)
921+
DPCTLEvent_Wait(htERef)
839922

840-
# free memory
841-
free(arg_objects)
923+
# we are not returning host-task event at the moment
924+
DPCTLEvent_Delete(htERef)
842925

843926
return SyclEvent._create(Eref)
844927

0 commit comments

Comments
 (0)