Skip to content

Commit ba29585

Browse files
nyalloc authored and Steffen Larsen committed
[PI][CUDA] Implementation of piEventSetCallback with tests
Signed-off-by: Stuart Adams <[email protected]>
Signed-off-by: Steffen Larsen <[email protected]>
Signed-off-by: Ruyman Reyes <[email protected]>
1 parent fc03fda commit ba29585

File tree

6 files changed

+242
-65
lines changed

6 files changed

+242
-65
lines changed

sycl/plugins/cuda/pi_cuda.cpp

Lines changed: 49 additions & 43 deletions
Original file line number | Diff line number | Diff line change
@@ -136,6 +136,8 @@ pi_result _pi_event::start() {
136136
}
137137

138138
isStarted_ = true;
139+
// let observers know that the event is "submitted"
140+
trigger_callback(get_execution_status());
139141
return result;
140142
}
141143

@@ -165,6 +167,22 @@ pi_result _pi_event::record() {
165167

166168
try {
167169
result = PI_CHECK_ERROR(cuEventRecord(event_, cuStream));
170+
171+
result = cuda_piEventRetain(this);
172+
try {
173+
result = PI_CHECK_ERROR(cuLaunchHostFunc(
174+
cuStream,
175+
[](void *userData) {
176+
pi_event event = reinterpret_cast<pi_event>(userData);
177+
event->set_event_complete();
178+
cuda_piEventRelease(event);
179+
},
180+
this));
181+
} catch (...) {
182+
// If host function fails to enqueue we must release the event here
183+
result = cuda_piEventRelease(this);
184+
throw;
185+
}
168186
} catch (pi_result error) {
169187
result = error;
170188
}
@@ -185,6 +203,7 @@ pi_result _pi_event::wait() {
185203
if (is_native_event()) {
186204
try {
187205
retErr = PI_CHECK_ERROR(cuEventSynchronize(event_));
206+
isCompleted_ = true;
188207
} catch (pi_result error) {
189208
retErr = error;
190209
}
@@ -196,30 +215,12 @@ pi_result _pi_event::wait() {
196215
retErr = PI_SUCCESS;
197216
}
198217

199-
return retErr;
200-
}
201-
202-
pi_event_status _pi_event::get_execution_status() const noexcept {
218+
auto is_success = retErr == PI_SUCCESS;
219+
auto status = is_success ? get_execution_status() : pi_int32(retErr);
203220

204-
if (!is_recorded()) {
205-
return PI_EVENT_SUBMITTED;
206-
}
207-
208-
if (is_native_event()) {
209-
// native event status
210-
211-
auto status = cuEventQuery(get());
212-
if (status == CUDA_ERROR_NOT_READY) {
213-
return PI_EVENT_RUNNING;
214-
} else if (status != CUDA_SUCCESS) {
215-
cl::sycl::detail::pi::die("Invalid CUDA event status");
216-
}
217-
return PI_EVENT_COMPLETE;
218-
} else {
219-
// user event status
221+
trigger_callback(status);
220222

221-
return is_completed() ? PI_EVENT_COMPLETE : PI_EVENT_RUNNING;
222-
}
223+
return retErr;
223224
}
224225

225226
// iterates over the event wait list, returns correct pi_result error codes.
@@ -2256,24 +2257,21 @@ pi_result cuda_piEventGetInfo(pi_event event, pi_event_info param_name,
22562257

22572258
switch (param_name) {
22582259
case PI_EVENT_INFO_QUEUE:
2259-
return getInfo<pi_queue>(param_value_size, param_value,
2260-
param_value_size_ret, event->get_queue());
2260+
return getInfo(param_value_size, param_value, param_value_size_ret,
2261+
event->get_queue());
22612262
case PI_EVENT_INFO_COMMAND_TYPE:
2262-
return getInfo<pi_command_type>(param_value_size, param_value,
2263-
param_value_size_ret,
2264-
event->get_command_type());
2263+
return getInfo(param_value_size, param_value, param_value_size_ret,
2264+
event->get_command_type());
22652265
case PI_EVENT_INFO_REFERENCE_COUNT:
2266-
return getInfo<pi_uint32>(param_value_size, param_value,
2267-
param_value_size_ret,
2268-
event->get_reference_count());
2266+
return getInfo(param_value_size, param_value, param_value_size_ret,
2267+
event->get_reference_count());
22692268
case PI_EVENT_INFO_COMMAND_EXECUTION_STATUS: {
2270-
return getInfo<pi_event_status>(param_value_size, param_value,
2271-
param_value_size_ret,
2272-
event->get_execution_status());
2269+
return getInfo(param_value_size, param_value, param_value_size_ret,
2270+
static_cast<pi_event_status>(event->get_execution_status()));
22732271
}
22742272
case PI_EVENT_INFO_CONTEXT:
2275-
return getInfo<pi_context>(param_value_size, param_value,
2276-
param_value_size_ret, event->get_context());
2273+
return getInfo(param_value_size, param_value, param_value_size_ret,
2274+
event->get_context());
22772275
default:
22782276
PI_HANDLE_UNKNOWN_PARAM_NAME(param_name);
22792277
}
@@ -2304,13 +2302,21 @@ pi_result cuda_piEventGetProfilingInfo(
23042302
return {};
23052303
}
23062304

2307-
pi_result cuda_piEventSetCallback(
2308-
pi_event event, pi_int32 command_exec_callback_type,
2309-
void (*pfn_notify)(pi_event event, pi_int32 event_command_status,
2310-
void *user_data),
2311-
void *user_data) {
2312-
cl::sycl::detail::pi::die("cuda_piEventSetCallback not implemented");
2313-
return {};
2305+
/// Registers \p notify as an observer of \p event.
///
/// \param event Event to observe; must be non-null.
/// \param command_exec_callback_type Execution status of interest; must be one
///        of PI_EVENT_SUBMITTED, PI_EVENT_RUNNING or PI_EVENT_COMPLETE.
/// \param notify Function to call; must be non-null.
/// \param user_data Opaque pointer handed back to \p notify unchanged.
/// \return PI_SUCCESS. Invalid arguments are only caught by assertions.
pi_result cuda_piEventSetCallback(pi_event event,
                                  pi_int32 command_exec_callback_type,
                                  pfn_notify notify, void *user_data) {
  assert(event);
  assert(notify);
  assert(command_exec_callback_type == PI_EVENT_SUBMITTED ||
         command_exec_callback_type == PI_EVENT_RUNNING ||
         command_exec_callback_type == PI_EVENT_COMPLETE);

  const auto observed_status =
      static_cast<pi_event_status>(command_exec_callback_type);

  // The event decides whether to fire immediately or store the callback.
  event->set_event_callback(
      event_callback{observed_status, notify, user_data});

  return PI_SUCCESS;
}
23152321

23162322
pi_result cuda_piEventSetStatus(pi_event event, pi_int32 execution_status) {
@@ -2323,7 +2329,7 @@ pi_result cuda_piEventSetStatus(pi_event event, pi_int32 execution_status) {
23232329
}
23242330

23252331
if (execution_status == PI_EVENT_COMPLETE) {
2326-
return event->set_user_event_complete();
2332+
return event->set_event_complete();
23272333
} else if (execution_status < 0) {
23282334
// TODO: A negative integer value causes all enqueued commands that wait
23292335
// on this user event to be terminated.

sycl/plugins/cuda/pi_cuda.hpp

Lines changed: 87 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -228,6 +228,39 @@ struct _pi_queue {
228228
pi_uint32 get_reference_count() const noexcept { return refCount_; }
229229
};
230230

231+
typedef void (*pfn_notify)(pi_event event, pi_int32 eventCommandStatus,
232+
void *userData);
233+
234+
class event_callback {
235+
public:
236+
void trigger_callback(pi_event event, pi_int32 currentEventStatus) const {
237+
238+
auto validParameters = callback_ && event;
239+
240+
// As a pi_event_status value approaches 0, it gets closer to completion.
241+
// If the calling pi_event's status is less than or equal to the event
242+
// status the user is interested in, invoke the callback anyway. The event
243+
// will have passed through that state anyway.
244+
auto validStatus = currentEventStatus <= observedEventStatus_;
245+
246+
if (validParameters && validStatus) {
247+
248+
callback_(event, currentEventStatus, userData_);
249+
}
250+
}
251+
252+
event_callback(pi_event_status status, pfn_notify callback, void *userData)
253+
: observedEventStatus_{status}, callback_{callback}, userData_{userData} {
254+
}
255+
256+
pi_event_status get_status() const noexcept { return observedEventStatus_; }
257+
258+
private:
259+
pi_event_status observedEventStatus_;
260+
pfn_notify callback_;
261+
void *userData_;
262+
};
263+
231264
class _pi_event {
232265
public:
233266
using native_type = CUevent;
@@ -240,18 +273,39 @@ class _pi_event {
240273

241274
native_type get() const noexcept { return event_; };
242275

243-
pi_result set_user_event_complete() noexcept {
276+
// Marks this event as recorded and completed, then notifies the registered
// callbacks with the resulting execution status.
// Returns PI_INVALID_OPERATION if the event was already completed, and
// PI_SUCCESS otherwise.
// NOTE(review): this span is a rendered diff; lines whose gutter number ends
// in '-' are the removed pre-change body and are not part of the new function.
pi_result set_event_complete() noexcept {
244277

245278
if (isCompleted_) {
246279
return PI_INVALID_OPERATION;
247280
}
248281

249-
if (is_user_event()) {
250-
isRecorded_ = true;
251-
isCompleted_ = true;
252-
return PI_SUCCESS;
282+
isRecorded_ = true;
283+
isCompleted_ = true;
284+
285+
// Both flags are set before this call, so get_execution_status() yields
// PI_EVENT_COMPLETE here.
trigger_callback(get_execution_status());
286+
287+
return PI_SUCCESS;
288+
}
289+
290+
/// Fires every registered callback whose observed status has been reached.
///
/// \param status Current execution status of this event, or a negative error
///        code. Smaller values are further along, so a callback is due when
///        \p status is less than or equal to the status it was registered
///        for.
///
/// Callbacks that are not yet due stay registered, so a later status change
/// can still deliver them. Previously every callback was taken off the list
/// regardless, and the ones that were not due were silently discarded.
void trigger_callback(pi_int32 status) {
  std::vector<event_callback> due;
  std::vector<event_callback> pending;

  // The list is split under the lock, but the callbacks run only after the
  // lock is released. This is a defensive maneuver: a callback may register
  // further callbacks on this event, which would otherwise lock mutex_ twice
  // and modify the vector while it is being iterated over.
  {
    std::lock_guard<std::mutex> lock(mutex_);
    for (const auto &callback : event_callbacks_) {
      if (status <= callback.get_status()) {
        due.push_back(callback);
      } else {
        pending.push_back(callback);
      }
    }
    event_callbacks_.swap(pending);
  }

  for (const auto &callback : due) {
    callback.trigger_callback(this, status);
  }
}
256310

257311
pi_queue get_queue() const noexcept { return queue_; }
@@ -266,7 +320,27 @@ class _pi_event {
266320

267321
bool is_started() const noexcept { return isStarted_; }
268322

269-
pi_event_status get_execution_status() const noexcept;
323+
/// Derives the execution status from the recorded/completed flags:
/// PI_EVENT_SUBMITTED until the event is recorded, PI_EVENT_RUNNING while it
/// is recorded but not completed, and PI_EVENT_COMPLETE afterwards.
pi_int32 get_execution_status() const noexcept {
  if (is_recorded()) {
    return is_completed() ? PI_EVENT_COMPLETE : PI_EVENT_RUNNING;
  }
  return PI_EVENT_SUBMITTED;
}
334+
335+
void set_event_callback(const event_callback &callback) {
336+
auto current_status = get_execution_status();
337+
if (current_status <= callback.get_status()) {
338+
callback.trigger_callback(this, current_status);
339+
} else {
340+
std::lock_guard<std::mutex> lock(mutex_);
341+
event_callbacks_.emplace_back(callback);
342+
}
343+
}
270344

271345
pi_context get_context() const noexcept { return context_; };
272346

@@ -326,6 +400,12 @@ class _pi_event {
326400
pi_context context_; // pi_context associated with the event. If this is a
327401
// native event, this will be the same context associated
328402
// with the queue_ member.
403+
404+
std::mutex mutex_; // Protect access to event_callbacks_. TODO: There might be
405+
// a lock-free data structure we can use here.
406+
std::vector<event_callback>
407+
event_callbacks_; // Callbacks that can be triggered when an event's state
408+
// changes.
329409
};
330410

331411
struct _pi_program {

sycl/test/basic_tests/buffer/buffer_dev_to_dev.cpp

Lines changed: 0 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -4,9 +4,6 @@
44
// RUN: %GPU_RUN_PLACEHOLDER %t.out
55
// RUN: %ACC_RUN_PLACEHOLDER %t.out
66

7-
// TODO: pi_die: cuda_piEventSetCallback not implemented
8-
// XFAIL: cuda
9-
107
//==---------- buffer_dev_to_dev.cpp - SYCL buffer basic test --------------==//
118
//
129
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.

sycl/test/scheduler/DataMovement.cpp

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,6 @@
11
// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple -I %sycl_source_dir %s -o %t.out
22
// RUN: %t.out
33
//
4-
// XFAIL: cuda
54
//==-------------------------- DataMovement.cpp ----------------------------==//
65
//
76
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.

sycl/test/scheduler/MultipleDevices.cpp

Lines changed: 0 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,6 @@
11
// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple -I %sycl_source_dir %s -o %t.out
22
// RUN: %t.out
33

4-
// TODO: pi_die: cuda_piEventSetCallback not implemented
5-
// XFAIL: cuda
6-
74
//===- MultipleDevices.cpp - Test checking multi-device execution --------===//
85
//
96
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.

0 commit comments

Comments
 (0)