Skip to content

Commit 86acff3

Browse files
authored
[SYCL][CUDA] Expose context extended deleters on PI API (#1483)
Signed-off-by: Stuart Adams [email protected]
1 parent 76e3c46 commit 86acff3

File tree

6 files changed

+49
-11
lines changed

6 files changed

+49
-11
lines changed

sycl/include/CL/sycl/detail/pi.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ _PI_API(piContextCreate)
3131
_PI_API(piContextGetInfo)
3232
_PI_API(piContextRetain)
3333
_PI_API(piContextRelease)
34+
_PI_API(piextContextSetExtendedDeleter)
3435
// Queue
3536
_PI_API(piQueueCreate)
3637
_PI_API(piQueueGetInfo)

sycl/include/CL/sycl/detail/pi.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -829,6 +829,12 @@ pi_result piContextRetain(pi_context context);
829829

830830
pi_result piContextRelease(pi_context context);
831831

832+
// Callback type for a context "extended deleter": a plain function pointer
// that takes one opaque user_data pointer and returns nothing.
// NOTE(review): kept as a typedef rather than a `using` alias because this
// header presumably has to remain consumable as C - confirm before changing.
typedef void (*pi_context_extended_deleter)(void *user_data);
833+
834+
// Registers the callback `func` on `context`, together with the opaque
// `user_data` pointer that accompanies it.
// NOTE(review): in the CUDA plugin the registered callbacks are invoked with
// their user_data when the context is released; behaviour of other plugins is
// not visible here - confirm before relying on it.
pi_result
piextContextSetExtendedDeleter(pi_context context,
                               pi_context_extended_deleter func,
                               void *user_data);
837+
832838
//
833839
// Queue
834840
//

sycl/include/CL/sycl/detail/pi.hpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ struct trace_event_data_t;
3131

3232
__SYCL_INLINE_NAMESPACE(cl) {
3333
namespace sycl {
34+
35+
class context;
36+
3437
namespace detail {
3538

3639
enum class PiApiKind {
@@ -96,6 +99,10 @@ using PiMemObjectType = ::pi_mem_type;
9699
using PiMemImageChannelOrder = ::pi_image_channel_order;
97100
using PiMemImageChannelType = ::pi_image_channel_type;
98101

102+
/// Registers an extended deleter on the PI context that backs \p context.
///
/// \param context   SYCL context whose underlying PI context receives the
///                  deleter.
/// \param func      Callback forwarded to the plugin's
///                  piextContextSetExtendedDeleter entry point.
/// \param user_data Opaque pointer forwarded alongside \p func.
void contextSetExtendedDeleter(const cl::sycl::context &context,
                               pi_context_extended_deleter func,
                               void *user_data);
105+
99106
// Function to load the shared library
100107
// Implementation is OS dependent.
101108
void *loadOsLibrary(const std::string &Library);

sycl/plugins/cuda/pi_cuda.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -777,6 +777,12 @@ pi_result cuda_piContextRetain(pi_context context) {
777777
return PI_SUCCESS;
778778
}
779779

780+
/// CUDA implementation of piextContextSetExtendedDeleter.
///
/// Hands \p function and \p user_data to the context object, which stores
/// them for later invocation. Always reports PI_SUCCESS.
/// NOTE(review): \p context is dereferenced without a null check - callers
/// are assumed to pass a valid context; confirm.
pi_result cuda_piextContextSetExtendedDeleter(
    pi_context context,
    pi_context_extended_deleter function,
    void *user_data) {
  auto &target = *context;
  target.set_extended_deleter(function, user_data);
  return PI_SUCCESS;
}
785+
780786
/// Not applicable to CUDA, devices cannot be partitioned.
781787
///
782788
pi_result cuda_piDevicePartition(
@@ -1459,7 +1465,7 @@ pi_result cuda_piContextRelease(pi_context ctxt) {
14591465
if (ctxt->decrement_reference_count() > 0) {
14601466
return PI_SUCCESS;
14611467
}
1462-
ctxt->invoke_callback();
1468+
ctxt->invoke_extended_deleters();
14631469

14641470
std::unique_ptr<_pi_context> context{ctxt};
14651471

@@ -3583,6 +3589,7 @@ pi_result piPluginInit(pi_plugin *PluginInit) {
35833589
_PI_CL(piextDeviceSelectBinary, cuda_piextDeviceSelectBinary)
35843590
_PI_CL(piextGetDeviceFunctionPointer, cuda_piextGetDeviceFunctionPointer)
35853591
// Context
3592+
_PI_CL(piextContextSetExtendedDeleter, cuda_piextContextSetExtendedDeleter)
35863593
_PI_CL(piContextCreate, cuda_piContextCreate)
35873594
_PI_CL(piContextGetInfo, cuda_piContextGetInfo)
35883595
_PI_CL(piContextRetain, cuda_piContextRetain)

sycl/plugins/cuda/pi_cuda.hpp

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,14 @@ class _pi_device {
121121
/// See proposal for details.
122122
///
123123
struct _pi_context {
124+
125+
struct deleter_data {
126+
pi_context_extended_deleter function;
127+
void *user_data;
128+
129+
void operator()() { function(user_data); }
130+
};
131+
124132
using native_type = CUcontext;
125133

126134
enum class kind { primary, user_defined } kind_;
@@ -138,20 +146,17 @@ struct _pi_context {
138146

139147
~_pi_context() { cuda_piDeviceRelease(deviceId_); }
140148

141-
void invoke_callback()
142-
{
149+
void invoke_extended_deleters() {
143150
std::lock_guard<std::mutex> guard(mutex_);
144-
for(const auto& callback : destruction_callbacks_)
145-
{
146-
callback();
151+
for (auto &deleter : extended_deleters_) {
152+
deleter();
147153
}
148154
}
149155

150-
template<typename Func>
151-
void register_callback(Func&& callback)
152-
{
156+
void set_extended_deleter(pi_context_extended_deleter function,
157+
void *user_data) {
153158
std::lock_guard<std::mutex> guard(mutex_);
154-
destruction_callbacks_.emplace_back(std::forward<Func>(callback));
159+
extended_deleters_.emplace_back(deleter_data{function, user_data});
155160
}
156161

157162
pi_device get_device() const noexcept { return deviceId_; }
@@ -168,7 +173,7 @@ struct _pi_context {
168173

169174
private:
170175
std::mutex mutex_;
171-
std::vector<std::function<void(void)>> destruction_callbacks_;
176+
std::vector<deleter_data> extended_deleters_;
172177
};
173178

174179
/// PI Mem mapping to a CUDA memory allocation

sycl/source/detail/pi.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
///
1212
/// \ingroup sycl_pi
1313

14+
#include "context_impl.hpp"
15+
#include <CL/sycl/context.hpp>
1416
#include <CL/sycl/detail/common.hpp>
1517
#include <CL/sycl/detail/pi.hpp>
1618
#include <detail/plugin.hpp>
@@ -53,6 +55,16 @@ namespace pi {
5355

5456
bool XPTIInitDone = false;
5557

58+
/// Registers an extended deleter on the PI context behind \p context.
///
/// Looks up the context implementation, takes its native handle and plugin,
/// and forwards \p func / \p user_data to the plugin's
/// piextContextSetExtendedDeleter entry point. The call goes through
/// call_nocheck and its result is discarded, so a failure in the plugin is
/// not reported to the caller.
void contextSetExtendedDeleter(const cl::sycl::context &context,
                               pi_context_extended_deleter func,
                               void *user_data) {
  const auto ctxImpl = getSyclObjImpl(context);
  const auto nativeCtx = reinterpret_cast<pi_context>(ctxImpl->getHandleRef());
  auto piPlugin = ctxImpl->getPlugin();
  piPlugin.call_nocheck<PiApiKind::piextContextSetExtendedDeleter>(
      nativeCtx, func, user_data);
}
67+
5668
std::string platformInfoToString(pi_platform_info info) {
5769
switch (info) {
5870
case PI_PLATFORM_INFO_PROFILE:

0 commit comments

Comments
 (0)