Skip to content

Commit b5fd007

Browse files
nyalloc authored and
Steffen Larsen committed
[PI][CUDA] Implementation of piEventSetCallback with tests
Signed-off-by: Stuart Adams <[email protected]> Signed-off-by: Steffen Larsen <[email protected]> Signed-off-by: Ruyman Reyes <[email protected]>
1 parent fcadd26 commit b5fd007

File tree

6 files changed

+218
-57
lines changed

6 files changed

+218
-57
lines changed

sycl/plugins/cuda/pi_cuda.cpp

Lines changed: 49 additions & 43 deletions
Original file line number | Diff line number | Diff line change
@@ -149,6 +149,8 @@ pi_result _pi_event::start() {
149149
}
150150

151151
isStarted_ = true;
152+
// let observers know that the event is "submitted"
153+
trigger_callback(get_execution_status());
152154
return result;
153155
}
154156

@@ -195,6 +197,22 @@ pi_result _pi_event::record() {
195197

196198
try {
197199
result = PI_CHECK_ERROR(cuEventRecord(evEnd_, cuStream));
200+
201+
result = cuda_piEventRetain(this);
202+
try {
203+
result = PI_CHECK_ERROR(cuLaunchHostFunc(
204+
cuStream,
205+
[](void *userData) {
206+
pi_event event = reinterpret_cast<pi_event>(userData);
207+
event->set_event_complete();
208+
cuda_piEventRelease(event);
209+
},
210+
this));
211+
} catch (...) {
212+
// If host function fails to enqueue we must release the event here
213+
result = cuda_piEventRelease(this);
214+
throw;
215+
}
198216
} catch (pi_result error) {
199217
result = error;
200218
}
@@ -215,6 +233,7 @@ pi_result _pi_event::wait() {
215233
if (is_native_event()) {
216234
try {
217235
retErr = PI_CHECK_ERROR(cuEventSynchronize(evEnd_));
236+
isCompleted_ = true;
218237
} catch (pi_result error) {
219238
retErr = error;
220239
}
@@ -226,30 +245,12 @@ pi_result _pi_event::wait() {
226245
retErr = PI_SUCCESS;
227246
}
228247

229-
return retErr;
230-
}
231-
232-
pi_event_status _pi_event::get_execution_status() const noexcept {
248+
auto is_success = retErr == PI_SUCCESS;
249+
auto status = is_success ? get_execution_status() : pi_int32(retErr);
233250

234-
if (!is_recorded()) {
235-
return PI_EVENT_SUBMITTED;
236-
}
237-
238-
if (is_native_event()) {
239-
// native event status
240-
241-
auto status = cuEventQuery(get());
242-
if (status == CUDA_ERROR_NOT_READY) {
243-
return PI_EVENT_RUNNING;
244-
} else if (status != CUDA_SUCCESS) {
245-
cl::sycl::detail::pi::die("Invalid CUDA event status");
246-
}
247-
return PI_EVENT_COMPLETE;
248-
} else {
249-
// user event status
251+
trigger_callback(status);
250252

251-
return is_completed() ? PI_EVENT_COMPLETE : PI_EVENT_RUNNING;
252-
}
253+
return retErr;
253254
}
254255

255256
// iterates over the event wait list, returns correct pi_result error codes.
@@ -2516,24 +2517,21 @@ pi_result cuda_piEventGetInfo(pi_event event, pi_event_info param_name,
25162517

25172518
switch (param_name) {
25182519
case PI_EVENT_INFO_COMMAND_QUEUE:
2519-
return getInfo<pi_queue>(param_value_size, param_value,
2520-
param_value_size_ret, event->get_queue());
2520+
return getInfo(param_value_size, param_value, param_value_size_ret,
2521+
event->get_queue());
25212522
case PI_EVENT_INFO_COMMAND_TYPE:
2522-
return getInfo<pi_command_type>(param_value_size, param_value,
2523-
param_value_size_ret,
2524-
event->get_command_type());
2523+
return getInfo(param_value_size, param_value, param_value_size_ret,
2524+
event->get_command_type());
25252525
case PI_EVENT_INFO_REFERENCE_COUNT:
2526-
return getInfo<pi_uint32>(param_value_size, param_value,
2527-
param_value_size_ret,
2528-
event->get_reference_count());
2526+
return getInfo(param_value_size, param_value, param_value_size_ret,
2527+
event->get_reference_count());
25292528
case PI_EVENT_INFO_COMMAND_EXECUTION_STATUS: {
2530-
return getInfo<pi_event_status>(param_value_size, param_value,
2531-
param_value_size_ret,
2532-
event->get_execution_status());
2529+
return getInfo(param_value_size, param_value, param_value_size_ret,
2530+
static_cast<pi_event_status>(event->get_execution_status()));
25332531
}
25342532
case PI_EVENT_INFO_CONTEXT:
2535-
return getInfo<pi_context>(param_value_size, param_value,
2536-
param_value_size_ret, event->get_context());
2533+
return getInfo(param_value_size, param_value, param_value_size_ret,
2534+
event->get_context());
25372535
default:
25382536
PI_HANDLE_UNKNOWN_PARAM_NAME(param_name);
25392537
}
@@ -2568,13 +2566,21 @@ pi_result cuda_piEventGetProfilingInfo(
25682566
return {};
25692567
}
25702568

2571-
pi_result cuda_piEventSetCallback(
2572-
pi_event event, pi_int32 command_exec_callback_type,
2573-
void (*pfn_notify)(pi_event event, pi_int32 event_command_status,
2574-
void *user_data),
2575-
void *user_data) {
2576-
cl::sycl::detail::pi::die("cuda_piEventSetCallback not implemented");
2577-
return {};
2569+
// Registers `notify` on `event` for the execution status given by
// `command_exec_callback_type`. `user_data` is handed back to `notify`
// untouched when the callback is invoked.
//
// The arguments are only validated through assert(); the function itself
// always reports PI_SUCCESS.
pi_result cuda_piEventSetCallback(pi_event event,
                                  pi_int32 command_exec_callback_type,
                                  pfn_notify notify, void *user_data) {
  assert(event);
  assert(notify);
  // Only the three regular execution states may be observed.
  assert(command_exec_callback_type == PI_EVENT_COMPLETE ||
         command_exec_callback_type == PI_EVENT_RUNNING ||
         command_exec_callback_type == PI_EVENT_SUBMITTED);

  const auto observed_status =
      static_cast<pi_event_status>(command_exec_callback_type);

  // The event decides whether the callback fires right away or is stored.
  event->set_event_callback(event_callback{observed_status, notify, user_data});

  return PI_SUCCESS;
}
25792585

25802586
pi_result cuda_piEventSetStatus(pi_event event, pi_int32 execution_status) {
@@ -2587,7 +2593,7 @@ pi_result cuda_piEventSetStatus(pi_event event, pi_int32 execution_status) {
25872593
}
25882594

25892595
if (execution_status == PI_EVENT_COMPLETE) {
2590-
return event->set_user_event_complete();
2596+
return event->set_event_complete();
25912597
} else if (execution_status < 0) {
25922598
// TODO: A negative integer value causes all enqueued commands that wait
25932599
// on this user event to be terminated.

sycl/plugins/cuda/pi_cuda.hpp

Lines changed: 87 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -235,6 +235,39 @@ struct _pi_queue {
235235
pi_uint32 get_reference_count() const noexcept { return refCount_; }
236236
};
237237

238+
typedef void (*pfn_notify)(pi_event event, pi_int32 eventCommandStatus,
239+
void *userData);
240+
241+
class event_callback {
242+
public:
243+
void trigger_callback(pi_event event, pi_int32 currentEventStatus) const {
244+
245+
auto validParameters = callback_ && event;
246+
247+
// As a pi_event_status value approaches 0, it gets closer to completion.
248+
// If the calling pi_event's status is less than or equal to the event
249+
// status the user is interested in, invoke the callback anyway. The event
250+
// will have passed through that state anyway.
251+
auto validStatus = currentEventStatus <= observedEventStatus_;
252+
253+
if (validParameters && validStatus) {
254+
255+
callback_(event, currentEventStatus, userData_);
256+
}
257+
}
258+
259+
event_callback(pi_event_status status, pfn_notify callback, void *userData)
260+
: observedEventStatus_{status}, callback_{callback}, userData_{userData} {
261+
}
262+
263+
pi_event_status get_status() const noexcept { return observedEventStatus_; }
264+
265+
private:
266+
pi_event_status observedEventStatus_;
267+
pfn_notify callback_;
268+
void *userData_;
269+
};
270+
238271
class _pi_event {
239272
public:
240273
using native_type = CUevent;
@@ -247,18 +280,39 @@ class _pi_event {
247280

248281
native_type get() const noexcept { return evEnd_; };
249282

250-
pi_result set_user_event_complete() noexcept {
283+
// Marks the event as recorded and completed, then notifies the registered
// callbacks with the resulting execution status.
//
// Returns PI_INVALID_OPERATION when the event was already completed,
// PI_SUCCESS otherwise.
pi_result set_event_complete() noexcept {
  if (!isCompleted_) {
    isRecorded_ = true;
    isCompleted_ = true;

    trigger_callback(get_execution_status());

    return PI_SUCCESS;
  }

  // Completing an event twice is rejected.
  return PI_INVALID_OPERATION;
}
296+
297+
// Invokes every registered callback whose observed status has been reached by
// `status`. Callbacks that are not yet due stay registered so that a later
// status change can still fire them.
//
// `status` is the execution status (or error code) reported to the callbacks.
void trigger_callback(pi_int32 status) {

  std::vector<event_callback> ready;

  // Callbacks that are due are copied into a local vector before any of them
  // is called. This is a defensive maneuver: if a callback registers further
  // callbacks on this event, the mutex must not be held (it would be locked
  // twice) and the member vector must not be iterated while it is modified.
  {
    std::lock_guard<std::mutex> lock(mutex_);

    std::vector<event_callback> pending;
    for (const auto &callback : event_callbacks_) {
      // Same rule as event_callback::trigger_callback: a status at or below
      // the observed one means the observed state has been reached.
      if (status <= callback.get_status()) {
        ready.push_back(callback);
      } else {
        // Not due yet: keep it instead of silently dropping it.
        pending.push_back(callback);
      }
    }
    event_callbacks_.swap(pending);
  }

  for (const auto &callback : ready) {
    callback.trigger_callback(this, status);
  }
}
263317

264318
pi_queue get_queue() const noexcept { return queue_; }
@@ -273,7 +327,27 @@ class _pi_event {
273327

274328
bool is_started() const noexcept { return isStarted_; }
275329

276-
pi_event_status get_execution_status() const noexcept;
330+
// Derives the execution status from the recorded/completed flags:
// not recorded -> PI_EVENT_SUBMITTED, recorded but not completed ->
// PI_EVENT_RUNNING, otherwise PI_EVENT_COMPLETE.
pi_int32 get_execution_status() const noexcept {
  if (is_recorded()) {
    return is_completed() ? PI_EVENT_COMPLETE : PI_EVENT_RUNNING;
  }
  return PI_EVENT_SUBMITTED;
}
341+
342+
void set_event_callback(const event_callback &callback) {
343+
auto current_status = get_execution_status();
344+
if (current_status <= callback.get_status()) {
345+
callback.trigger_callback(this, current_status);
346+
} else {
347+
std::lock_guard<std::mutex> lock(mutex_);
348+
event_callbacks_.emplace_back(callback);
349+
}
350+
}
277351

278352
pi_context get_context() const noexcept { return context_; };
279353

@@ -343,6 +417,12 @@ class _pi_event {
343417
pi_context context_; // pi_context associated with the event. If this is a
344418
// native event, this will be the same context associated
345419
// with the queue_ member.
420+
421+
std::mutex mutex_; // Protect access to event_callbacks_. TODO: There might be
422+
// a lock-free data structure we can use here.
423+
std::vector<event_callback>
424+
event_callbacks_; // Callbacks that can be triggered when an event's state
425+
// changes.
346426
};
347427

348428
struct _pi_program {

sycl/test/basic_tests/buffer/buffer_dev_to_dev.cpp

Lines changed: 0 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -4,9 +4,6 @@
44
// RUN: %GPU_RUN_PLACEHOLDER %t.out
55
// RUN: %ACC_RUN_PLACEHOLDER %t.out
66

7-
// TODO: pi_die: cuda_piEventSetCallback not implemented
8-
// XFAIL: cuda
9-
107
//==---------- buffer_dev_to_dev.cpp - SYCL buffer basic test --------------==//
118
//
129
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.

sycl/test/scheduler/DataMovement.cpp

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,6 @@
11
// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple -I %sycl_source_dir %s -o %t.out
22
// RUN: %t.out
33
//
4-
// XFAIL: cuda
54
//==-------------------------- DataMovement.cpp ----------------------------==//
65
//
76
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.

sycl/test/scheduler/MultipleDevices.cpp

Lines changed: 0 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,6 @@
11
// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple -I %sycl_source_dir %s -o %t.out
22
// RUN: %t.out
33

4-
// TODO: pi_die: cuda_piEventSetCallback not implemented
5-
// XFAIL: cuda
6-
74
//===- MultipleDevices.cpp - Test checking multi-device execution --------===//
85
//
96
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.

sycl/unittests/pi/EventTest.cpp

Lines changed: 82 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -89,6 +89,88 @@ TEST_F(DISABLED_EventTest, PICreateEvent) {
8989
PI_SUCCESS);
9090
}
9191

92+
constexpr size_t event_type_count = 3;
93+
static bool triggered_flag[event_type_count] = {false, false, false};
94+
95+
struct callback_user_data {
96+
pi_int32 event_type;
97+
int index;
98+
};
99+
100+
// Callback registered by the piEventSetCallback test. `data` must point to a
// callback_user_data; the flag at its index is raised to show that the
// callback ran.
void EventCallback(pi_event event, pi_int32 status, void *data) {
  ASSERT_NE(data, nullptr);

  const auto *info = static_cast<const callback_user_data *>(data);

#ifndef NDEBUG
  printf("\tEvent callback %d of type %d triggered\n", info->index,
         info->event_type);
#endif

  triggered_flag[info->index] = true;
}
112+
113+
// Registers one callback per event state on an enqueued buffer write and
// checks that all of them have fired once the write has been waited on.
TEST_F(DISABLED_EventTest, piEventSetCallback) {

  // The flags are file-static: reset them so a repeated run of this test
  // cannot pass on values left behind by an earlier run.
  for (size_t k = 0; k < event_type_count; ++k) {
    triggered_flag[k] = false;
  }

  pi_int32 event_callback_types[event_type_count] = {
      PI_EVENT_SUBMITTED, PI_EVENT_RUNNING, PI_EVENT_COMPLETE};

  callback_user_data user_data[event_type_count];

  // gate event lets us register callbacks before letting the enqueued work be
  // executed.
  pi_event gateEvent;
  ASSERT_EQ((Plugins[0].call_nocheck<detail::PiApiKind::piEventCreate>(
                _context, &gateEvent)),
            PI_SUCCESS);

  constexpr const size_t dataCount = 1000u;
  std::vector<int> data(dataCount);
  auto size_in_bytes = data.size() * sizeof(int);

  pi_mem memObj;
  ASSERT_EQ(
      (Plugins[0].call_nocheck<detail::PiApiKind::piMemBufferCreate>(
          _context, PI_MEM_FLAGS_ACCESS_RW, size_in_bytes, nullptr, &memObj)),
      PI_SUCCESS);

  // Non-blocking write that waits on the gate event.
  pi_event syncEvent;
  ASSERT_EQ(
      (Plugins[0].call_nocheck<detail::PiApiKind::piEnqueueMemBufferWrite>(
          _queue, memObj, false, 0, size_in_bytes, data.data(), 1, &gateEvent,
          &syncEvent)),
      PI_SUCCESS);

  for (size_t i = 0; i < event_type_count; i++) {
    user_data[i].event_type = event_callback_types[i];
    user_data[i].index = static_cast<int>(i);
    ASSERT_EQ(
        (Plugins[0].call_nocheck<detail::PiApiKind::piEventSetCallback>(
            syncEvent, event_callback_types[i], EventCallback, user_data + i)),
        PI_SUCCESS);
  }

  // Open the gate, then wait for the write and drain the queue.
  ASSERT_EQ((Plugins[0].call_nocheck<detail::PiApiKind::piEventSetStatus>(
                gateEvent, PI_EVENT_COMPLETE)),
            PI_SUCCESS);
  ASSERT_EQ(
      (Plugins[0].call_nocheck<detail::PiApiKind::piEventsWait>(1, &syncEvent)),
      PI_SUCCESS);
  ASSERT_EQ((Plugins[0].call_nocheck<detail::PiApiKind::piQueueFinish>(_queue)),
            PI_SUCCESS);

  for (size_t k = 0; k < event_type_count; ++k) {
    EXPECT_TRUE(triggered_flag[k]);
  }

  ASSERT_EQ(
      (Plugins[0].call_nocheck<detail::PiApiKind::piEventRelease>(gateEvent)),
      PI_SUCCESS);
  ASSERT_EQ(
      (Plugins[0].call_nocheck<detail::PiApiKind::piEventRelease>(syncEvent)),
      PI_SUCCESS);

  // Release the buffer created above so the test does not leak it.
  ASSERT_EQ((Plugins[0].call_nocheck<detail::PiApiKind::piMemRelease>(memObj)),
            PI_SUCCESS);
}
173+
92174
TEST_F(DISABLED_EventTest, piEventGetInfo) {
93175

94176
pi_event foo;

0 commit comments

Comments
 (0)