@@ -64,22 +64,41 @@ inline void assign_result(pi_result *ptr, pi_result value) noexcept {
64
64
}
65
65
66
66
// Iterates over the event wait list, returns correct pi_result error codes.
67
- // Invokes the callback for each event in the wait list. The callback must take
68
- // a single pi_event argument and return a pi_result.
67
+ // Invokes the callback for the latest event of each queue in the wait list.
68
+ // The callback must take a single pi_event argument and return a pi_result.
69
69
template <typename Func>
70
- pi_result forEachEvent (const pi_event *event_wait_list,
71
- std::size_t num_events_in_wait_list, Func &&f) {
70
+ pi_result forLatestEvents (const pi_event *event_wait_list,
71
+ std::size_t num_events_in_wait_list, Func &&f) {
72
72
73
73
if (event_wait_list == nullptr || num_events_in_wait_list == 0 ) {
74
74
return PI_INVALID_EVENT_WAIT_LIST;
75
75
}
76
76
77
- for (size_t i = 0 ; i < num_events_in_wait_list; i++) {
78
- auto event = event_wait_list[i];
79
- if (event == nullptr ) {
80
- return PI_INVALID_EVENT_WAIT_LIST;
77
+ // Fast path if we only have a single event
78
+ if (num_events_in_wait_list == 1 ) {
79
+ return f (event_wait_list[0 ]);
80
+ }
81
+
82
+ std::vector<pi_event> events{event_wait_list,
83
+ event_wait_list + num_events_in_wait_list};
84
+ std::sort (events.begin (), events.end (), [](pi_event e0 , pi_event e1 ) {
85
+ // Two-level sort: events are grouped by stream (smallest stream value
86
+ // first) and, within each stream, ordered by descending event id (newest first).
87
+ return e0 ->get_queue ()->stream_ < e1 ->get_queue ()->stream_ ||
88
+ (e0 ->get_queue ()->stream_ == e1 ->get_queue ()->stream_ &&
89
+ e0 ->get_event_id () > e1 ->get_event_id ());
90
+ });
91
+
92
+ bool first = true ;
93
+ CUstream lastSeenStream = 0 ;
94
+ for (pi_event event : events) {
95
+ if (!event || (!first && event->get_queue ()->stream_ == lastSeenStream)) {
96
+ continue ;
81
97
}
82
98
99
+ first = false ;
100
+ lastSeenStream = event->get_queue ()->stream_ ;
101
+
83
102
auto result = f (event);
84
103
if (result != PI_SUCCESS) {
85
104
return result;
@@ -357,6 +376,11 @@ pi_result _pi_event::record() {
357
376
CUstream cuStream = queue_->get ();
358
377
359
378
try {
379
+ eventId_ = queue_->get_next_event_id ();
380
+ if (eventId_ == 0 ) {
381
+ cl::sycl::detail::pi::die (
382
+ " Unrecoverable program state reached in event identifier overflow" );
383
+ }
360
384
result = PI_CHECK_ERROR (cuEventRecord (evEnd_, cuStream));
361
385
} catch (pi_result error) {
362
386
result = error;
@@ -1961,8 +1985,8 @@ pi_result cuda_piEnqueueMemBufferRead(pi_queue command_queue, pi_mem buffer,
1961
1985
pi_result cuda_piEventsWait (pi_uint32 num_events, const pi_event *event_list) {
1962
1986
1963
1987
try {
1964
- pi_result err = PI_SUCCESS ;
1965
-
1988
+ assert (num_events != 0 ) ;
1989
+ assert (event_list);
1966
1990
if (num_events == 0 ) {
1967
1991
return PI_INVALID_VALUE;
1968
1992
}
@@ -1974,11 +1998,7 @@ pi_result cuda_piEventsWait(pi_uint32 num_events, const pi_event *event_list) {
1974
1998
auto context = event_list[0 ]->get_context ();
1975
1999
ScopedContext active (context);
1976
2000
1977
- for (pi_uint32 count = 0 ; count < num_events && (err == PI_SUCCESS);
1978
- count++) {
1979
-
1980
- auto event = event_list[count];
1981
-
2001
+ auto waitFunc = [context](pi_event event) -> pi_result {
1982
2002
if (!event) {
1983
2003
return PI_INVALID_EVENT;
1984
2004
}
@@ -1987,9 +2007,9 @@ pi_result cuda_piEventsWait(pi_uint32 num_events, const pi_event *event_list) {
1987
2007
return PI_INVALID_CONTEXT;
1988
2008
}
1989
2009
1990
- err = event->wait ();
1991
- }
1992
- return err ;
2010
+ return event->wait ();
2011
+ };
2012
+ return forLatestEvents (event_list, num_events, waitFunc) ;
1993
2013
} catch (pi_result err) {
1994
2014
return err;
1995
2015
} catch (...) {
@@ -2763,10 +2783,10 @@ pi_result cuda_piEnqueueEventsWait(pi_queue command_queue,
2763
2783
2764
2784
if (event_wait_list) {
2765
2785
auto result =
2766
- forEachEvent (event_wait_list, num_events_in_wait_list,
2767
- [command_queue](pi_event event) -> pi_result {
2768
- return enqueueEventWait (command_queue, event);
2769
- });
2786
+ forLatestEvents (event_wait_list, num_events_in_wait_list,
2787
+ [command_queue](pi_event event) -> pi_result {
2788
+ return enqueueEventWait (command_queue, event);
2789
+ });
2770
2790
2771
2791
if (result != PI_SUCCESS) {
2772
2792
return result;
0 commit comments