
Commit dab89ae

Handle nullptr event in enqueue
1 parent 43e29fc commit dab89ae
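
This commit adds a `withTimingEvent` helper that factors out the common enqueue bookkeeping: it waits on the incoming event wait list, runs the actual work in a lambda, and allocates, ticks, and returns a profiling event only when the caller supplied a non-null `phEvent`. Several of these entry points previously wrote `*phEvent` unconditionally, and `urEnqueueEventsWait` / `urEnqueueEventsWaitWithBarrier` simply hit `DIE_NO_IMPLEMENTATION`. Below is a minimal caller-side sketch (not part of this commit) of what the null-event path now permits; it assumes the public `ur_api.h` header and that the queue and USM pointers were created elsewhere:

// Hypothetical usage sketch: with the helper in place, a client may pass
// nullptr for phEvent and no event object is allocated or written through
// the null pointer.
#include <ur_api.h>

static ur_result_t copyWithoutEvent(ur_queue_handle_t hQueue, void *dst,
                                    const void *src, size_t size) {
  return urEnqueueUSMMemcpy(hQueue, /*blocking=*/true, dst, src, size,
                            /*numEventsInWaitList=*/0,
                            /*phEventWaitList=*/nullptr,
                            /*phEvent=*/nullptr);
}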

1 file changed: +149 −147 lines

source/adapters/native_cpu/enqueue.cpp

Lines changed: 149 additions & 147 deletions
@@ -258,26 +258,43 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
   return UR_RESULT_SUCCESS;
 }
 
+ur_result_t withTimingEvent(ur_command_t command_type, ur_queue_handle_t hQueue,
+                            uint32_t numEventsInWaitList,
+                            const ur_event_handle_t *phEventWaitList,
+                            ur_event_handle_t *phEvent,
+                            const std::function<ur_result_t()> &f) {
+  urEventWait(numEventsInWaitList, phEventWaitList);
+  ur_event_handle_t event;
+  if (phEvent) {
+    event = new ur_event_handle_t_(hQueue, command_type);
+    event->tick_start();
+  }
+
+  ur_result_t result = f();
+
+  if (phEvent) {
+    event->tick_end();
+    *phEvent = event;
+  }
+  return result;
+}
+
 UR_APIEXPORT ur_result_t UR_APICALL urEnqueueEventsWait(
     ur_queue_handle_t hQueue, uint32_t numEventsInWaitList,
     const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
-  std::ignore = hQueue;
-  std::ignore = numEventsInWaitList;
-  std::ignore = phEventWaitList;
-  std::ignore = phEvent;
 
-  DIE_NO_IMPLEMENTATION;
+  // TODO: the wait here should be async
+  return withTimingEvent(UR_COMMAND_EVENTS_WAIT, hQueue, numEventsInWaitList,
+                         phEventWaitList, phEvent,
+                         [&]() { return UR_RESULT_SUCCESS; });
 }
 
 UR_APIEXPORT ur_result_t UR_APICALL urEnqueueEventsWaitWithBarrier(
     ur_queue_handle_t hQueue, uint32_t numEventsInWaitList,
     const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
-  std::ignore = hQueue;
-  std::ignore = numEventsInWaitList;
-  std::ignore = phEventWaitList;
-  std::ignore = phEvent;
-
-  DIE_NO_IMPLEMENTATION;
+  return withTimingEvent(UR_COMMAND_EVENTS_WAIT_WITH_BARRIER, hQueue,
+                         numEventsInWaitList, phEventWaitList, phEvent,
+                         [&]() { return UR_RESULT_SUCCESS; });
 }
 
 template <bool IsRead>
@@ -289,43 +306,42 @@ static inline ur_result_t enqueueMemBufferReadWriteRect_impl(
     typename std::conditional<IsRead, void *, const void *>::type DstMem,
     uint32_t NumEventsInWaitList, const ur_event_handle_t *phEventWaitList,
     ur_event_handle_t *phEvent) {
-  ur_event_handle_t event;
+  ur_command_t command_t;
   if constexpr (IsRead)
-    event = new ur_event_handle_t_(hQueue, UR_COMMAND_MEM_BUFFER_READ_RECT);
+    command_t = UR_COMMAND_MEM_BUFFER_READ_RECT;
   else
-    event = new ur_event_handle_t_(hQueue, UR_COMMAND_MEM_BUFFER_WRITE_RECT);
-  event->tick_start();
-  // TODO: blocking, check other constraints, performance optimizations
-  // More sharing with level_zero where possible
-
-  urEventWait(NumEventsInWaitList, phEventWaitList);
-  if (BufferRowPitch == 0)
-    BufferRowPitch = region.width;
-  if (BufferSlicePitch == 0)
-    BufferSlicePitch = BufferRowPitch * region.height;
-  if (HostRowPitch == 0)
-    HostRowPitch = region.width;
-  if (HostSlicePitch == 0)
-    HostSlicePitch = HostRowPitch * region.height;
-  for (size_t w = 0; w < region.width; w++)
-    for (size_t h = 0; h < region.height; h++)
-      for (size_t d = 0; d < region.depth; d++) {
-        size_t buff_orign = (d + BufferOffset.z) * BufferSlicePitch +
-                            (h + BufferOffset.y) * BufferRowPitch + w +
-                            BufferOffset.x;
-        size_t host_origin = (d + HostOffset.z) * HostSlicePitch +
-                             (h + HostOffset.y) * HostRowPitch + w +
-                             HostOffset.x;
-        int8_t &buff_mem = ur_cast<int8_t *>(Buff->_mem)[buff_orign];
-        if constexpr (IsRead)
-          ur_cast<int8_t *>(DstMem)[host_origin] = buff_mem;
-        else
-          buff_mem = ur_cast<const int8_t *>(DstMem)[host_origin];
-      }
+    command_t = UR_COMMAND_MEM_BUFFER_WRITE_RECT;
+  return withTimingEvent(
+      command_t, hQueue, NumEventsInWaitList, phEventWaitList, phEvent, [&]() {
+        // TODO: blocking, check other constraints, performance optimizations
+        // More sharing with level_zero where possible
+
+        if (BufferRowPitch == 0)
+          BufferRowPitch = region.width;
+        if (BufferSlicePitch == 0)
+          BufferSlicePitch = BufferRowPitch * region.height;
+        if (HostRowPitch == 0)
+          HostRowPitch = region.width;
+        if (HostSlicePitch == 0)
+          HostSlicePitch = HostRowPitch * region.height;
+        for (size_t w = 0; w < region.width; w++)
+          for (size_t h = 0; h < region.height; h++)
+            for (size_t d = 0; d < region.depth; d++) {
+              size_t buff_orign = (d + BufferOffset.z) * BufferSlicePitch +
+                                  (h + BufferOffset.y) * BufferRowPitch + w +
+                                  BufferOffset.x;
+              size_t host_origin = (d + HostOffset.z) * HostSlicePitch +
+                                   (h + HostOffset.y) * HostRowPitch + w +
+                                   HostOffset.x;
+              int8_t &buff_mem = ur_cast<int8_t *>(Buff->_mem)[buff_orign];
+              if constexpr (IsRead)
+                ur_cast<int8_t *>(DstMem)[host_origin] = buff_mem;
+              else
+                buff_mem = ur_cast<const int8_t *>(DstMem)[host_origin];
+            }
 
-  event->tick_end();
-  *phEvent = event;
-  return UR_RESULT_SUCCESS;
+        return UR_RESULT_SUCCESS;
+      });
 }
 
 static inline ur_result_t doCopy_impl(ur_queue_handle_t hQueue, void *DstPtr,
@@ -334,15 +350,12 @@ static inline ur_result_t doCopy_impl(ur_queue_handle_t hQueue, void *DstPtr,
                                       const ur_event_handle_t *phEventWaitList,
                                       ur_event_handle_t *phEvent,
                                       ur_command_t command_type) {
-  ur_event_handle_t event = new ur_event_handle_t_(hQueue, command_type);
-  event->tick_start();
-  urEventWait(numEventsInWaitList, phEventWaitList);
-  if (SrcPtr != DstPtr && Size)
-    memmove(DstPtr, SrcPtr, Size);
-  event->tick_end();
-  if (phEvent)
-    *phEvent = event;
-  return UR_RESULT_SUCCESS;
+  return withTimingEvent(command_type, hQueue, numEventsInWaitList,
+                         phEventWaitList, phEvent, [&]() {
+                           if (SrcPtr != DstPtr && Size)
+                             memmove(DstPtr, SrcPtr, Size);
+                           return UR_RESULT_SUCCESS;
+                         });
 }
 
 UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferRead(
@@ -426,22 +439,23 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferFill(
     size_t patternSize, size_t offset, size_t size,
     uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
     ur_event_handle_t *phEvent) {
-  std::ignore = numEventsInWaitList;
-  std::ignore = phEventWaitList;
-  std::ignore = phEvent;
-
-  UR_ASSERT(hQueue, UR_RESULT_ERROR_INVALID_NULL_HANDLE);
 
-  // TODO: error checking
-  // TODO: handle async
-  void *startingPtr = hBuffer->_mem + offset;
-  unsigned steps = size / patternSize;
-  for (unsigned i = 0; i < steps; i++) {
-    memcpy(static_cast<int8_t *>(startingPtr) + i * patternSize, pPattern,
-           patternSize);
-  }
+  return withTimingEvent(
+      UR_COMMAND_MEM_BUFFER_FILL, hQueue, numEventsInWaitList, phEventWaitList,
+      phEvent, [&]() {
+        UR_ASSERT(hQueue, UR_RESULT_ERROR_INVALID_NULL_HANDLE);
+
+        // TODO: error checking
+        // TODO: handle async
+        void *startingPtr = hBuffer->_mem + offset;
+        unsigned steps = size / patternSize;
+        for (unsigned i = 0; i < steps; i++) {
+          memcpy(static_cast<int8_t *>(startingPtr) + i * patternSize, pPattern,
+                 patternSize);
+        }
 
-  return UR_RESULT_SUCCESS;
+        return UR_RESULT_SUCCESS;
+      });
 }
 
 UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageRead(
@@ -512,15 +526,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferMap(
   std::ignore = mapFlags;
   std::ignore = size;
 
-  urEventWait(numEventsInWaitList, phEventWaitList);
-  ur_event_handle_t event =
-      new ur_event_handle_t_(hQueue, UR_COMMAND_MEM_BUFFER_MAP);
-  event->tick_start();
-  *ppRetMap = hBuffer->_mem + offset;
-  event->tick_end();
-  *phEvent = event;
-
-  return UR_RESULT_SUCCESS;
+  return withTimingEvent(UR_COMMAND_MEM_BUFFER_MAP, hQueue, numEventsInWaitList,
+                         phEventWaitList, phEvent, [&]() {
+                           *ppRetMap = hBuffer->_mem + offset;
+                           return UR_RESULT_SUCCESS;
+                         });
 }
 
 UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemUnmap(
@@ -529,91 +539,83 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemUnmap(
     ur_event_handle_t *phEvent) {
   std::ignore = hMem;
   std::ignore = pMappedPtr;
-  urEventWait(numEventsInWaitList, phEventWaitList);
-  *phEvent = new ur_event_handle_t_(hQueue, UR_COMMAND_MEM_UNMAP);
-
-  return UR_RESULT_SUCCESS;
+  return withTimingEvent(UR_COMMAND_MEM_UNMAP, hQueue, numEventsInWaitList,
+                         phEventWaitList, phEvent,
+                         [&]() { return UR_RESULT_SUCCESS; });
 }
 
 UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill(
     ur_queue_handle_t hQueue, void *ptr, size_t patternSize,
     const void *pPattern, size_t size, uint32_t numEventsInWaitList,
     const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
-  urEventWait(numEventsInWaitList, phEventWaitList);
-  ur_event_handle_t event =
-      new ur_event_handle_t_(hQueue, UR_COMMAND_MEM_BUFFER_MAP);
-  event->tick_start();
-
-  UR_ASSERT(ptr, UR_RESULT_ERROR_INVALID_NULL_POINTER);
-  UR_ASSERT(pPattern, UR_RESULT_ERROR_INVALID_NULL_POINTER);
-  UR_ASSERT(patternSize != 0, UR_RESULT_ERROR_INVALID_SIZE)
-  UR_ASSERT(size != 0, UR_RESULT_ERROR_INVALID_SIZE)
-  UR_ASSERT(patternSize < size, UR_RESULT_ERROR_INVALID_SIZE)
-  UR_ASSERT(size % patternSize == 0, UR_RESULT_ERROR_INVALID_SIZE)
-  // TODO: add check for allocation size once the query is supported
-
-  switch (patternSize) {
-  case 1:
-    memset(ptr, *static_cast<const uint8_t *>(pPattern), size * patternSize);
-    break;
-  case 2: {
-    const auto pattern = *static_cast<const uint16_t *>(pPattern);
-    auto *start = reinterpret_cast<uint16_t *>(ptr);
-    auto *end =
-        reinterpret_cast<uint16_t *>(reinterpret_cast<uint8_t *>(ptr) + size);
-    std::fill(start, end, pattern);
-    break;
-  }
-  case 4: {
-    const auto pattern = *static_cast<const uint32_t *>(pPattern);
-    auto *start = reinterpret_cast<uint32_t *>(ptr);
-    auto *end =
-        reinterpret_cast<uint32_t *>(reinterpret_cast<uint8_t *>(ptr) + size);
-    std::fill(start, end, pattern);
-    break;
-  }
-  case 8: {
-    const auto pattern = *static_cast<const uint64_t *>(pPattern);
-    auto *start = reinterpret_cast<uint64_t *>(ptr);
-    auto *end =
-        reinterpret_cast<uint64_t *>(reinterpret_cast<uint8_t *>(ptr) + size);
-    std::fill(start, end, pattern);
-    break;
-  }
-  default: {
-    for (unsigned int step{0}; step < size; step += patternSize) {
-      auto *dest =
-          reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(ptr) + step);
-      memcpy(dest, pPattern, patternSize);
-    }
-  }
-  }
-
-  event->tick_end();
-  *phEvent = event;
-
-  return UR_RESULT_SUCCESS;
+  return withTimingEvent(
+      UR_COMMAND_USM_FILL, hQueue, numEventsInWaitList, phEventWaitList,
+      phEvent, [&]() {
+        UR_ASSERT(ptr, UR_RESULT_ERROR_INVALID_NULL_POINTER);
+        UR_ASSERT(pPattern, UR_RESULT_ERROR_INVALID_NULL_POINTER);
+        UR_ASSERT(patternSize != 0, UR_RESULT_ERROR_INVALID_SIZE)
+        UR_ASSERT(size != 0, UR_RESULT_ERROR_INVALID_SIZE)
+        UR_ASSERT(patternSize < size, UR_RESULT_ERROR_INVALID_SIZE)
+        UR_ASSERT(size % patternSize == 0, UR_RESULT_ERROR_INVALID_SIZE)
+        // TODO: add check for allocation size once the query is supported
+
+        switch (patternSize) {
+        case 1:
+          memset(ptr, *static_cast<const uint8_t *>(pPattern),
+                 size * patternSize);
+          break;
+        case 2: {
+          const auto pattern = *static_cast<const uint16_t *>(pPattern);
+          auto *start = reinterpret_cast<uint16_t *>(ptr);
+          auto *end = reinterpret_cast<uint16_t *>(
+              reinterpret_cast<uint8_t *>(ptr) + size);
+          std::fill(start, end, pattern);
+          break;
+        }
+        case 4: {
+          const auto pattern = *static_cast<const uint32_t *>(pPattern);
+          auto *start = reinterpret_cast<uint32_t *>(ptr);
+          auto *end = reinterpret_cast<uint32_t *>(
+              reinterpret_cast<uint8_t *>(ptr) + size);
+          std::fill(start, end, pattern);
+          break;
+        }
+        case 8: {
+          const auto pattern = *static_cast<const uint64_t *>(pPattern);
+          auto *start = reinterpret_cast<uint64_t *>(ptr);
+          auto *end = reinterpret_cast<uint64_t *>(
+              reinterpret_cast<uint8_t *>(ptr) + size);
+          std::fill(start, end, pattern);
+          break;
+        }
+        default: {
+          for (unsigned int step{0}; step < size; step += patternSize) {
+            auto *dest = reinterpret_cast<void *>(
+                reinterpret_cast<uint8_t *>(ptr) + step);
+            memcpy(dest, pPattern, patternSize);
+          }
+        }
+        }
+        return UR_RESULT_SUCCESS;
+      });
 }
 
 UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy(
     ur_queue_handle_t hQueue, bool blocking, void *pDst, const void *pSrc,
     size_t size, uint32_t numEventsInWaitList,
     const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
   std::ignore = blocking;
-  urEventWait(numEventsInWaitList, phEventWaitList);
-  ur_event_handle_t event =
-      new ur_event_handle_t_(hQueue, UR_COMMAND_MEM_BUFFER_MAP);
-  event->tick_start();
-
-  UR_ASSERT(hQueue, UR_RESULT_ERROR_INVALID_QUEUE);
-  UR_ASSERT(pDst, UR_RESULT_ERROR_INVALID_NULL_POINTER);
-  UR_ASSERT(pSrc, UR_RESULT_ERROR_INVALID_NULL_POINTER);
+  return withTimingEvent(
+      UR_COMMAND_USM_MEMCPY, hQueue, numEventsInWaitList, phEventWaitList,
+      phEvent, [&]() {
+        UR_ASSERT(hQueue, UR_RESULT_ERROR_INVALID_QUEUE);
+        UR_ASSERT(pDst, UR_RESULT_ERROR_INVALID_NULL_POINTER);
+        UR_ASSERT(pSrc, UR_RESULT_ERROR_INVALID_NULL_POINTER);
 
-  memcpy(pDst, pSrc, size);
-  event->tick_end();
-  *phEvent = event;
+        memcpy(pDst, pSrc, size);
 
-  return UR_RESULT_SUCCESS;
+        return UR_RESULT_SUCCESS;
+      });
 }
 
 UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch(

0 commit comments
