15
15
#include " logger/ur_logger.hpp"
16
16
#include " queue_handle.hpp"
17
17
18
+ ur_result_t ur_execution_event_handle_t::assign (ur_event_handle_t hNewEvent) {
19
+ assert (hNewEvent);
20
+ assert (hNewEvent->getURQueueHandle ());
21
+
22
+ auto newQueue = hNewEvent->getURQueueHandle ();
23
+ auto currentQueue = hEvent ? hEvent->getURQueueHandle () : nullptr ;
24
+
25
+ if (hEvent) {
26
+ UR_CALL (hEvent->release ());
27
+ }
28
+
29
+ hEvent = hNewEvent;
30
+
31
+ if (newQueue != currentQueue) {
32
+ if (currentQueue)
33
+ UR_CALL (currentQueue->queueRelease ());
34
+ UR_CALL (newQueue->queueRetain ());
35
+ }
36
+
37
+ return UR_RESULT_SUCCESS;
38
+ }
39
+
40
+ ur_event_handle_t ur_execution_event_handle_t::get () { return hEvent; }
41
+
42
+ ur_result_t ur_execution_event_handle_t::release () {
43
+ if (hEvent) {
44
+ assert (hEvent->getURQueueHandle ());
45
+
46
+ auto hQueue = hEvent->getURQueueHandle ();
47
+ UR_CALL_NOCHECK (hEvent->release ());
48
+ UR_CALL_NOCHECK (hQueue->queueRelease ());
49
+
50
+ hEvent = nullptr ;
51
+ }
52
+ return UR_RESULT_SUCCESS;
53
+ }
54
+
55
+ ur_execution_event_handle_t ::~ur_execution_event_handle_t () { release (); }
56
+
18
57
namespace {
19
58
20
59
ur_result_t getZeKernelWrapped (ur_kernel_handle_t kernel,
@@ -69,13 +108,12 @@ ur_exp_command_buffer_handle_t_::ur_exp_command_buffer_handle_t_(
69
108
: eventPool(context->getEventPoolCache (PoolCacheType::Regular)
70
109
.borrow(device->Id.value(),
71
110
isInOrder ? v2::EVENT_FLAGS_COUNTER : 0)),
72
- context(context), device(device),
111
+ context(context), device(device), currentExecution( nullptr ),
73
112
isUpdatable(desc ? desc->isUpdatable : false ),
74
113
isInOrder(desc ? desc->isInOrder : false ),
75
114
commandListManager(
76
115
context, device,
77
- std::forward<v2::raii::command_list_unique_handle>(commandList))
78
- {}
116
+ std::forward<v2::raii::command_list_unique_handle>(commandList)) {}
79
117
80
118
ur_exp_command_buffer_sync_point_t
81
119
ur_exp_command_buffer_handle_t_::getSyncPoint (ur_event_handle_t event) {
@@ -146,25 +184,16 @@ ur_result_t ur_exp_command_buffer_handle_t_::finalizeCommandBuffer() {
146
184
return UR_RESULT_SUCCESS;
147
185
}
148
186
ur_event_handle_t ur_exp_command_buffer_handle_t_::getExecutionEventUnlocked () {
149
- return currentExecution;
187
+ return currentExecution. get () ;
150
188
}
151
189
152
190
ur_result_t ur_exp_command_buffer_handle_t_::registerExecutionEventUnlocked (
153
191
ur_event_handle_t nextExecutionEvent) {
154
- if (currentExecution) {
155
- UR_CALL (currentExecution->release ());
156
- currentExecution = nullptr ;
157
- }
158
- if (nextExecutionEvent) {
159
- currentExecution = nextExecutionEvent;
160
- }
192
+ UR_CALL (currentExecution.assign (nextExecutionEvent));
161
193
return UR_RESULT_SUCCESS;
162
194
}
163
195
164
196
ur_exp_command_buffer_handle_t_::~ur_exp_command_buffer_handle_t_ () {
165
- if (currentExecution) {
166
- currentExecution->release ();
167
- }
168
197
for (auto &event : syncPoints) {
169
198
event->release ();
170
199
}
@@ -181,14 +210,13 @@ ur_result_t ur_exp_command_buffer_handle_t_::applyUpdateCommands(
181
210
this , device, context->getPlatform ()->ZeDriverGlobalOffsetExtensionFound ,
182
211
numUpdateCommands, updateCommands));
183
212
184
- if (currentExecution) {
213
+ if (currentExecution. get () ) {
185
214
// TODO: Move synchronization to command buffer enqueue
186
215
// it would require to remember the update commands and perform update
187
216
// before appending to the queue
188
217
ZE2UR_CALL (zeEventHostSynchronize,
189
- (currentExecution->getZeEvent (), UINT64_MAX));
190
- currentExecution->release ();
191
- currentExecution = nullptr ;
218
+ (currentExecution.get ()->getZeEvent (), UINT64_MAX));
219
+ UR_CALL (currentExecution.release ());
192
220
}
193
221
194
222
device_ptr_storage_t zeHandles;
0 commit comments