Skip to content

Commit beef6a2

Browse files
Added SyclQueue._submit_keep_args_alive method
Usage: q = dpctl.SyclQueue(); ...; e = q.submit(krn, args, ranges); ht_e = q._submit_keep_args_alive(args, [e]); ...; ht_e.wait()
1 parent 83fff33 commit beef6a2

File tree

3 files changed

+132
-33
lines changed

3 files changed

+132
-33
lines changed

dpctl/_host_task_util.hpp

Lines changed: 29 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
//
33
// Data Parallel Control (dpctl)
44
//
5-
// Copyright 2020-2022 Intel Corporation
5+
// Copyright 2020-2023 Intel Corporation
66
//
77
// Licensed under the Apache License, Version 2.0 (the "License");
88
// you may not use this file except in compliance with the License.
@@ -31,28 +31,27 @@
3131

3232
#include "Python.h"
3333
#include "syclinterface/dpctl_data_types.h"
34+
#include "syclinterface/dpctl_sycl_type_casters.hpp"
3435
#include <CL/sycl.hpp>
3536

36-
int async_dec_ref(DPCTLSyclQueueRef QRef,
37-
PyObject **obj_array,
38-
size_t obj_array_size,
39-
DPCTLSyclEventRef *ERefs,
40-
size_t nERefs)
37+
DPCTLSyclEventRef async_dec_ref(DPCTLSyclQueueRef QRef,
38+
PyObject **obj_array,
39+
size_t obj_array_size,
40+
DPCTLSyclEventRef *depERefs,
41+
size_t nDepERefs,
42+
int *status)
4143
{
44+
using dpctl::syclinterface::unwrap;
45+
using dpctl::syclinterface::wrap;
4246

43-
sycl::queue *q = reinterpret_cast<sycl::queue *>(QRef);
47+
sycl::queue *q = unwrap<sycl::queue>(QRef);
4448

45-
std::vector<PyObject *> obj_vec;
46-
obj_vec.reserve(obj_array_size);
47-
for (size_t obj_id = 0; obj_id < obj_array_size; ++obj_id) {
48-
obj_vec.push_back(obj_array[obj_id]);
49-
}
49+
std::vector<PyObject *> obj_vec(obj_array, obj_array + obj_array_size);
5050

5151
try {
52-
q->submit([&](sycl::handler &cgh) {
53-
for (size_t ev_id = 0; ev_id < nERefs; ++ev_id) {
54-
cgh.depends_on(
55-
*(reinterpret_cast<sycl::event *>(ERefs[ev_id])));
52+
sycl::event ht_ev = q->submit([&](sycl::handler &cgh) {
53+
for (size_t ev_id = 0; ev_id < nDepERefs; ++ev_id) {
54+
cgh.depends_on(*(unwrap<sycl::event>(depERefs[ev_id])));
5655
}
5756
cgh.host_task([obj_array_size, obj_vec]() {
5857
// if the main thread has not finalized the interpreter yet
@@ -66,9 +65,21 @@ int async_dec_ref(DPCTLSyclQueueRef QRef,
6665
}
6766
});
6867
});
68+
69+
constexpr int result_ok = 0;
70+
71+
*status = result_ok;
72+
auto e_ptr = new sycl::event(ht_ev);
73+
return wrap<sycl::event>(e_ptr);
6974
} catch (const std::exception &e) {
70-
return 1;
75+
constexpr int result_std_exception = 1;
76+
77+
*status = result_std_exception;
78+
return nullptr;
7179
}
7280

73-
return 0;
81+
constexpr int result_other_abnormal = 2;
82+
83+
*status = result_other_abnormal;
84+
return nullptr;
7485
}

dpctl/_sycl_queue.pxd

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,11 @@ cdef public api class SyclQueue (_SyclQueue) [
7272
cpdef SyclContext get_sycl_context(self)
7373
cpdef SyclDevice get_sycl_device(self)
7474
cdef DPCTLSyclQueueRef get_queue_ref(self)
75+
cpdef SyclEvent _submit_keep_args_alive(
76+
self,
77+
object args,
78+
list dEvents
79+
)
7580
cpdef SyclEvent submit(
7681
self,
7782
SyclKernel kernel,

dpctl/_sycl_queue.pyx

Lines changed: 98 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ import logging
7373

7474

7575
cdef extern from "_host_task_util.hpp":
76-
int async_dec_ref(DPCTLSyclQueueRef, PyObject **, size_t, DPCTLSyclEventRef *, size_t) nogil
76+
DPCTLSyclEventRef async_dec_ref(DPCTLSyclQueueRef, PyObject **, size_t, DPCTLSyclEventRef *, size_t, int *) nogil
7777

7878

7979
__all__ = [
@@ -716,6 +716,79 @@ cdef class SyclQueue(_SyclQueue):
716716
"""
717717
return <size_t>self._queue_ref
718718

719+
720+
cpdef SyclEvent _submit_keep_args_alive(
    self,
    object args,
    list dEvents
):
    """ SyclQueue._submit_keep_args_alive(args, dEvents)

    Keeps objects in `args` alive until tasks associated with events
    complete.

    Args:
        args(object): Python object to keep alive.
            Typically a tuple with arguments to offloaded tasks
        dEvents(List[dpctl.SyclEvent]): Gating events
            The list of events associated with tasks
            working on Python objects collected in `args`.
    Returns:
        dpctl.SyclEvent
           The event associated with the submission of host task.

    Raises:
        MemoryError: if the temporary array of event references
            can not be allocated.
        TypeError: if an element of `dEvents` is not an instance
            of `dpctl.SyclEvent`.
        RuntimeError: if the host task could not be submitted.

    Increments reference count of `args` and schedules asynchronous
    ``host_task`` to decrement the count once dependent events are
    complete.

    N.B.: The `host_task` attempts to acquire Python GIL, and it is
    known to be unsafe during interpreter shutdown sequence. It is
    thus strongly advised to ensure that all submitted `host_task`
    complete before the end of the Python script.
    """
    cdef size_t nDE = len(dEvents)
    cdef size_t i = 0
    cdef DPCTLSyclEventRef *depEvents = NULL
    cdef PyObject *args_raw = NULL
    cdef DPCTLSyclEventRef htERef = NULL
    cdef int status = -1

    # Create the array of dependent events if any
    if nDE > 0:
        depEvents = (
            <DPCTLSyclEventRef*>malloc(nDE*sizeof(DPCTLSyclEventRef))
        )
        if not depEvents:
            raise MemoryError()
        for idx, de in enumerate(dEvents):
            if isinstance(de, SyclEvent):
                depEvents[idx] = (<SyclEvent>de).get_event_ref()
            else:
                # nothing has been submitted yet, only the scratch
                # array needs to be released before raising
                free(depEvents)
                raise TypeError(
                    "A sequence of dpctl.SyclEvent is expected"
                )

    # increment reference counts to list of arguments
    Py_INCREF(args)

    # schedule decrement
    args_raw = <PyObject *>args

    htERef = async_dec_ref(
        self.get_queue_ref(),
        &args_raw, 1,
        depEvents, nDE, &status
    )

    if (status != 0):
        # async_dec_ref returns NULL whenever it reports a non-zero
        # status, so there is no host-task event to wait on. Mirror the
        # failure path of `submit`: wait for the gating events, then
        # undo the increment made above, since no host task will do it.
        with nogil:
            for i in range(nDE):
                DPCTLEvent_Wait(depEvents[i])
        free(depEvents)
        Py_DECREF(args)
        raise RuntimeError("Could not submit keep_args_alive host_task")

    free(depEvents)

    return SyclEvent._create(htERef)
790+
791+
719792
cpdef SyclEvent submit(
720793
self,
721794
SyclKernel kernel,
@@ -728,13 +801,14 @@ cdef class SyclQueue(_SyclQueue):
728801
cdef _arg_data_type *kargty = NULL
729802
cdef DPCTLSyclEventRef *depEvents = NULL
730803
cdef DPCTLSyclEventRef Eref = NULL
804+
cdef DPCTLSyclEventRef htEref = NULL
731805
cdef int ret = 0
732806
cdef size_t gRange[3]
733807
cdef size_t lRange[3]
734808
cdef size_t nGS = len(gS)
735809
cdef size_t nLS = len(lS) if lS is not None else 0
736810
cdef size_t nDE = len(dEvents) if dEvents is not None else 0
737-
cdef PyObject **arg_objects = NULL
811+
cdef PyObject *args_raw = NULL
738812
cdef ssize_t i = 0
739813

740814
# Allocate the arrays to be sent to DPCTLQueue_Submit
@@ -758,7 +832,15 @@ cdef class SyclQueue(_SyclQueue):
758832
raise MemoryError()
759833
else:
760834
for idx, de in enumerate(dEvents):
761-
depEvents[idx] = (<SyclEvent>de).get_event_ref()
835+
if isinstance(de, SyclEvent):
836+
depEvents[idx] = (<SyclEvent>de).get_event_ref()
837+
else:
838+
free(kargs)
839+
free(kargty)
840+
free(depEvents)
841+
raise TypeError(
842+
"A sequence of dpctl.SyclEvent is expected"
843+
)
762844

763845
# populate the args and argstype arrays
764846
ret = self._populate_args(args, kargs, kargty)
@@ -836,22 +918,23 @@ cdef class SyclQueue(_SyclQueue):
836918
raise SyclKernelSubmitError(
837919
"Kernel submission to Sycl queue failed."
838920
)
839-
# increment reference counts to each argument
840-
arg_objects = <PyObject **>malloc(len(args) * sizeof(PyObject *))
841-
for i in range(len(args)):
842-
arg_objects[i] = <PyObject *>(args[i])
843-
Py_INCREF(<object> arg_objects[i])
921+
# increment reference counts to list of arguments
922+
Py_INCREF(args)
844923

845924
# schedule decrement
846-
if async_dec_ref(self.get_queue_ref(), arg_objects, len(args), &Eref, 1):
925+
args_raw = <PyObject *>args
926+
927+
ret = -1
928+
htERef = async_dec_ref(self.get_queue_ref(), &args_raw, 1, &Eref, 1, &ret)
929+
if ret:
847930
# async task submission failed, decrement ref counts and wait
848-
for i in range(len(args)):
849-
arg_objects[i] = <PyObject *>(args[i])
850-
Py_DECREF(<object> arg_objects[i])
851-
with nogil: DPCTLEvent_Wait(Eref)
931+
Py_DECREF(args)
932+
with nogil:
933+
DPCTLEvent_Wait(Eref)
934+
DPCTLEvent_Wait(htERef)
852935

853-
# free memory
854-
free(arg_objects)
936+
# we are not returning host-task event at the moment
937+
DPCTLEvent_Delete(htERef)
855938

856939
return SyclEvent._create(Eref)
857940

0 commit comments

Comments
 (0)