@@ -258,26 +258,43 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
258
258
return UR_RESULT_SUCCESS;
259
259
}
260
260
261
+ ur_result_t withTimingEvent (ur_command_t command_type, ur_queue_handle_t hQueue,
262
+ uint32_t numEventsInWaitList,
263
+ const ur_event_handle_t *phEventWaitList,
264
+ ur_event_handle_t *phEvent,
265
+ const std::function<ur_result_t ()> &f) {
266
+ urEventWait (numEventsInWaitList, phEventWaitList);
267
+ ur_event_handle_t event;
268
+ if (phEvent) {
269
+ event = new ur_event_handle_t_ (hQueue, command_type);
270
+ event->tick_start ();
271
+ }
272
+
273
+ ur_result_t result = f ();
274
+
275
+ if (phEvent) {
276
+ event->tick_end ();
277
+ *phEvent = event;
278
+ }
279
+ return result;
280
+ }
281
+
261
282
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueEventsWait (
262
283
ur_queue_handle_t hQueue, uint32_t numEventsInWaitList,
263
284
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
264
- std::ignore = hQueue;
265
- std::ignore = numEventsInWaitList;
266
- std::ignore = phEventWaitList;
267
- std::ignore = phEvent;
268
285
269
- DIE_NO_IMPLEMENTATION;
286
+ // TODO: the wait here should be async
287
+ return withTimingEvent (UR_COMMAND_EVENTS_WAIT, hQueue, numEventsInWaitList,
288
+ phEventWaitList, phEvent,
289
+ [&]() { return UR_RESULT_SUCCESS; });
270
290
}
271
291
272
292
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueEventsWaitWithBarrier (
273
293
ur_queue_handle_t hQueue, uint32_t numEventsInWaitList,
274
294
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
275
- std::ignore = hQueue;
276
- std::ignore = numEventsInWaitList;
277
- std::ignore = phEventWaitList;
278
- std::ignore = phEvent;
279
-
280
- DIE_NO_IMPLEMENTATION;
295
+ return withTimingEvent (UR_COMMAND_EVENTS_WAIT_WITH_BARRIER, hQueue,
296
+ numEventsInWaitList, phEventWaitList, phEvent,
297
+ [&]() { return UR_RESULT_SUCCESS; });
281
298
}
282
299
283
300
template <bool IsRead>
@@ -289,43 +306,42 @@ static inline ur_result_t enqueueMemBufferReadWriteRect_impl(
289
306
typename std::conditional<IsRead, void *, const void *>::type DstMem,
290
307
uint32_t NumEventsInWaitList, const ur_event_handle_t *phEventWaitList,
291
308
ur_event_handle_t *phEvent) {
292
- ur_event_handle_t event ;
309
+ ur_command_t command_t ;
293
310
if constexpr (IsRead)
294
- event = new ur_event_handle_t_ (hQueue, UR_COMMAND_MEM_BUFFER_READ_RECT) ;
311
+ command_t = UR_COMMAND_MEM_BUFFER_READ_RECT;
295
312
else
296
- event = new ur_event_handle_t_ (hQueue, UR_COMMAND_MEM_BUFFER_WRITE_RECT) ;
297
- event-> tick_start ();
298
- // TODO: blocking, check other constraints, performance optimizations
299
- // More sharing with level_zero where possible
300
-
301
- urEventWait (NumEventsInWaitList, phEventWaitList);
302
- if (BufferRowPitch == 0 )
303
- BufferRowPitch = region.width ;
304
- if (BufferSlicePitch == 0 )
305
- BufferSlicePitch = BufferRowPitch * region.height ;
306
- if (HostRowPitch == 0 )
307
- HostRowPitch = region.width ;
308
- if (HostSlicePitch == 0 )
309
- HostSlicePitch = HostRowPitch * region.height ;
310
- for (size_t w = 0 ; w < region.width ; w++)
311
- for (size_t h = 0 ; h < region.height ; h++)
312
- for (size_t d = 0 ; d < region.depth ; d++) {
313
- size_t buff_orign = (d + BufferOffset.z ) * BufferSlicePitch +
314
- (h + BufferOffset.y ) * BufferRowPitch + w +
315
- BufferOffset.x ;
316
- size_t host_origin = (d + HostOffset.z ) * HostSlicePitch +
317
- (h + HostOffset.y ) * HostRowPitch + w +
318
- HostOffset.x ;
319
- int8_t &buff_mem = ur_cast<int8_t *>(Buff->_mem )[buff_orign];
320
- if constexpr (IsRead)
321
- ur_cast<int8_t *>(DstMem)[host_origin] = buff_mem;
322
- else
323
- buff_mem = ur_cast<const int8_t *>(DstMem)[host_origin];
324
- }
313
+ command_t = UR_COMMAND_MEM_BUFFER_WRITE_RECT;
314
+ return withTimingEvent (
315
+ command_t , hQueue, NumEventsInWaitList, phEventWaitList, phEvent, [&]() {
316
+ // TODO: blocking, check other constraints, performance optimizations
317
+ // More sharing with level_zero where possible
318
+
319
+ if (BufferRowPitch == 0 )
320
+ BufferRowPitch = region.width ;
321
+ if (BufferSlicePitch == 0 )
322
+ BufferSlicePitch = BufferRowPitch * region.height ;
323
+ if (HostRowPitch == 0 )
324
+ HostRowPitch = region.width ;
325
+ if (HostSlicePitch == 0 )
326
+ HostSlicePitch = HostRowPitch * region.height ;
327
+ for (size_t w = 0 ; w < region.width ; w++)
328
+ for (size_t h = 0 ; h < region.height ; h++)
329
+ for (size_t d = 0 ; d < region.depth ; d++) {
330
+ size_t buff_orign = (d + BufferOffset.z ) * BufferSlicePitch +
331
+ (h + BufferOffset.y ) * BufferRowPitch + w +
332
+ BufferOffset.x ;
333
+ size_t host_origin = (d + HostOffset.z ) * HostSlicePitch +
334
+ (h + HostOffset.y ) * HostRowPitch + w +
335
+ HostOffset.x ;
336
+ int8_t &buff_mem = ur_cast<int8_t *>(Buff->_mem )[buff_orign];
337
+ if constexpr (IsRead)
338
+ ur_cast<int8_t *>(DstMem)[host_origin] = buff_mem;
339
+ else
340
+ buff_mem = ur_cast<const int8_t *>(DstMem)[host_origin];
341
+ }
325
342
326
- event->tick_end ();
327
- *phEvent = event;
328
- return UR_RESULT_SUCCESS;
343
+ return UR_RESULT_SUCCESS;
344
+ });
329
345
}
330
346
331
347
static inline ur_result_t doCopy_impl (ur_queue_handle_t hQueue, void *DstPtr,
@@ -334,15 +350,12 @@ static inline ur_result_t doCopy_impl(ur_queue_handle_t hQueue, void *DstPtr,
334
350
const ur_event_handle_t *phEventWaitList,
335
351
ur_event_handle_t *phEvent,
336
352
ur_command_t command_type) {
337
- ur_event_handle_t event = new ur_event_handle_t_ (hQueue, command_type);
338
- event->tick_start ();
339
- urEventWait (numEventsInWaitList, phEventWaitList);
340
- if (SrcPtr != DstPtr && Size)
341
- memmove (DstPtr, SrcPtr, Size);
342
- event->tick_end ();
343
- if (phEvent)
344
- *phEvent = event;
345
- return UR_RESULT_SUCCESS;
353
+ return withTimingEvent (command_type, hQueue, numEventsInWaitList,
354
+ phEventWaitList, phEvent, [&]() {
355
+ if (SrcPtr != DstPtr && Size)
356
+ memmove (DstPtr, SrcPtr, Size);
357
+ return UR_RESULT_SUCCESS;
358
+ });
346
359
}
347
360
348
361
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferRead (
@@ -426,22 +439,23 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferFill(
426
439
size_t patternSize, size_t offset, size_t size,
427
440
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
428
441
ur_event_handle_t *phEvent) {
429
- std::ignore = numEventsInWaitList;
430
- std::ignore = phEventWaitList;
431
- std::ignore = phEvent;
432
-
433
- UR_ASSERT (hQueue, UR_RESULT_ERROR_INVALID_NULL_HANDLE);
434
442
435
- // TODO: error checking
436
- // TODO: handle async
437
- void *startingPtr = hBuffer->_mem + offset;
438
- unsigned steps = size / patternSize;
439
- for (unsigned i = 0 ; i < steps; i++) {
440
- memcpy (static_cast <int8_t *>(startingPtr) + i * patternSize, pPattern,
441
- patternSize);
442
- }
443
+ return withTimingEvent (
444
+ UR_COMMAND_MEM_BUFFER_FILL, hQueue, numEventsInWaitList, phEventWaitList,
445
+ phEvent, [&]() {
446
+ UR_ASSERT (hQueue, UR_RESULT_ERROR_INVALID_NULL_HANDLE);
447
+
448
+ // TODO: error checking
449
+ // TODO: handle async
450
+ void *startingPtr = hBuffer->_mem + offset;
451
+ unsigned steps = size / patternSize;
452
+ for (unsigned i = 0 ; i < steps; i++) {
453
+ memcpy (static_cast <int8_t *>(startingPtr) + i * patternSize, pPattern,
454
+ patternSize);
455
+ }
443
456
444
- return UR_RESULT_SUCCESS;
457
+ return UR_RESULT_SUCCESS;
458
+ });
445
459
}
446
460
447
461
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageRead (
@@ -512,15 +526,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferMap(
512
526
std::ignore = mapFlags;
513
527
std::ignore = size;
514
528
515
- urEventWait (numEventsInWaitList, phEventWaitList);
516
- ur_event_handle_t event =
517
- new ur_event_handle_t_ (hQueue, UR_COMMAND_MEM_BUFFER_MAP);
518
- event->tick_start ();
519
- *ppRetMap = hBuffer->_mem + offset;
520
- event->tick_end ();
521
- *phEvent = event;
522
-
523
- return UR_RESULT_SUCCESS;
529
+ return withTimingEvent (UR_COMMAND_MEM_BUFFER_MAP, hQueue, numEventsInWaitList,
530
+ phEventWaitList, phEvent, [&]() {
531
+ *ppRetMap = hBuffer->_mem + offset;
532
+ return UR_RESULT_SUCCESS;
533
+ });
524
534
}
525
535
526
536
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemUnmap (
@@ -529,91 +539,83 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemUnmap(
529
539
ur_event_handle_t *phEvent) {
530
540
std::ignore = hMem;
531
541
std::ignore = pMappedPtr;
532
- urEventWait (numEventsInWaitList, phEventWaitList);
533
- *phEvent = new ur_event_handle_t_ (hQueue, UR_COMMAND_MEM_UNMAP);
534
-
535
- return UR_RESULT_SUCCESS;
542
+ return withTimingEvent (UR_COMMAND_MEM_UNMAP, hQueue, numEventsInWaitList,
543
+ phEventWaitList, phEvent,
544
+ [&]() { return UR_RESULT_SUCCESS; });
536
545
}
537
546
538
547
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill (
539
548
ur_queue_handle_t hQueue, void *ptr, size_t patternSize,
540
549
const void *pPattern, size_t size, uint32_t numEventsInWaitList,
541
550
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
542
- urEventWait (numEventsInWaitList, phEventWaitList);
543
- ur_event_handle_t event =
544
- new ur_event_handle_t_ (hQueue, UR_COMMAND_MEM_BUFFER_MAP);
545
- event->tick_start ();
546
-
547
- UR_ASSERT (ptr, UR_RESULT_ERROR_INVALID_NULL_POINTER);
548
- UR_ASSERT (pPattern, UR_RESULT_ERROR_INVALID_NULL_POINTER);
549
- UR_ASSERT (patternSize != 0 , UR_RESULT_ERROR_INVALID_SIZE)
550
- UR_ASSERT (size != 0 , UR_RESULT_ERROR_INVALID_SIZE)
551
- UR_ASSERT (patternSize < size, UR_RESULT_ERROR_INVALID_SIZE)
552
- UR_ASSERT (size % patternSize == 0 , UR_RESULT_ERROR_INVALID_SIZE)
553
- // TODO: add check for allocation size once the query is supported
554
-
555
- switch (patternSize) {
556
- case 1 :
557
- memset (ptr, *static_cast <const uint8_t *>(pPattern), size * patternSize);
558
- break ;
559
- case 2 : {
560
- const auto pattern = *static_cast <const uint16_t *>(pPattern);
561
- auto *start = reinterpret_cast <uint16_t *>(ptr);
562
- auto *end =
563
- reinterpret_cast <uint16_t *>(reinterpret_cast <uint8_t *>(ptr) + size);
564
- std::fill (start, end, pattern);
565
- break ;
566
- }
567
- case 4 : {
568
- const auto pattern = *static_cast <const uint32_t *>(pPattern);
569
- auto *start = reinterpret_cast <uint32_t *>(ptr);
570
- auto *end =
571
- reinterpret_cast <uint32_t *>(reinterpret_cast <uint8_t *>(ptr) + size);
572
- std::fill (start, end, pattern);
573
- break ;
574
- }
575
- case 8 : {
576
- const auto pattern = *static_cast <const uint64_t *>(pPattern);
577
- auto *start = reinterpret_cast <uint64_t *>(ptr);
578
- auto *end =
579
- reinterpret_cast <uint64_t *>(reinterpret_cast <uint8_t *>(ptr) + size);
580
- std::fill (start, end, pattern);
581
- break ;
582
- }
583
- default : {
584
- for (unsigned int step{0 }; step < size; step += patternSize) {
585
- auto *dest =
586
- reinterpret_cast <void *>(reinterpret_cast <uint8_t *>(ptr) + step);
587
- memcpy (dest, pPattern, patternSize);
588
- }
589
- }
590
- }
591
-
592
- event->tick_end ();
593
- *phEvent = event;
594
-
595
- return UR_RESULT_SUCCESS;
551
+ return withTimingEvent (
552
+ UR_COMMAND_USM_FILL, hQueue, numEventsInWaitList, phEventWaitList,
553
+ phEvent, [&]() {
554
+ UR_ASSERT (ptr, UR_RESULT_ERROR_INVALID_NULL_POINTER);
555
+ UR_ASSERT (pPattern, UR_RESULT_ERROR_INVALID_NULL_POINTER);
556
+ UR_ASSERT (patternSize != 0 , UR_RESULT_ERROR_INVALID_SIZE)
557
+ UR_ASSERT (size != 0 , UR_RESULT_ERROR_INVALID_SIZE)
558
+ UR_ASSERT (patternSize < size, UR_RESULT_ERROR_INVALID_SIZE)
559
+ UR_ASSERT (size % patternSize == 0 , UR_RESULT_ERROR_INVALID_SIZE)
560
+ // TODO: add check for allocation size once the query is supported
561
+
562
+ switch (patternSize) {
563
+ case 1 :
564
+ memset (ptr, *static_cast <const uint8_t *>(pPattern),
565
+ size * patternSize);
566
+ break ;
567
+ case 2 : {
568
+ const auto pattern = *static_cast <const uint16_t *>(pPattern);
569
+ auto *start = reinterpret_cast <uint16_t *>(ptr);
570
+ auto *end = reinterpret_cast <uint16_t *>(
571
+ reinterpret_cast <uint8_t *>(ptr) + size);
572
+ std::fill (start, end, pattern);
573
+ break ;
574
+ }
575
+ case 4 : {
576
+ const auto pattern = *static_cast <const uint32_t *>(pPattern);
577
+ auto *start = reinterpret_cast <uint32_t *>(ptr);
578
+ auto *end = reinterpret_cast <uint32_t *>(
579
+ reinterpret_cast <uint8_t *>(ptr) + size);
580
+ std::fill (start, end, pattern);
581
+ break ;
582
+ }
583
+ case 8 : {
584
+ const auto pattern = *static_cast <const uint64_t *>(pPattern);
585
+ auto *start = reinterpret_cast <uint64_t *>(ptr);
586
+ auto *end = reinterpret_cast <uint64_t *>(
587
+ reinterpret_cast <uint8_t *>(ptr) + size);
588
+ std::fill (start, end, pattern);
589
+ break ;
590
+ }
591
+ default : {
592
+ for (unsigned int step{0 }; step < size; step += patternSize) {
593
+ auto *dest = reinterpret_cast <void *>(
594
+ reinterpret_cast <uint8_t *>(ptr) + step);
595
+ memcpy (dest, pPattern, patternSize);
596
+ }
597
+ }
598
+ }
599
+ return UR_RESULT_SUCCESS;
600
+ });
596
601
}
597
602
598
603
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy (
599
604
ur_queue_handle_t hQueue, bool blocking, void *pDst, const void *pSrc,
600
605
size_t size, uint32_t numEventsInWaitList,
601
606
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
602
607
std::ignore = blocking;
603
- urEventWait (numEventsInWaitList, phEventWaitList);
604
- ur_event_handle_t event =
605
- new ur_event_handle_t_ (hQueue, UR_COMMAND_MEM_BUFFER_MAP);
606
- event->tick_start ();
607
-
608
- UR_ASSERT (hQueue, UR_RESULT_ERROR_INVALID_QUEUE);
609
- UR_ASSERT (pDst, UR_RESULT_ERROR_INVALID_NULL_POINTER);
610
- UR_ASSERT (pSrc, UR_RESULT_ERROR_INVALID_NULL_POINTER);
608
+ return withTimingEvent (
609
+ UR_COMMAND_USM_MEMCPY, hQueue, numEventsInWaitList, phEventWaitList,
610
+ phEvent, [&]() {
611
+ UR_ASSERT (hQueue, UR_RESULT_ERROR_INVALID_QUEUE);
612
+ UR_ASSERT (pDst, UR_RESULT_ERROR_INVALID_NULL_POINTER);
613
+ UR_ASSERT (pSrc, UR_RESULT_ERROR_INVALID_NULL_POINTER);
611
614
612
- memcpy (pDst, pSrc, size);
613
- event->tick_end ();
614
- *phEvent = event;
615
+ memcpy (pDst, pSrc, size);
615
616
616
- return UR_RESULT_SUCCESS;
617
+ return UR_RESULT_SUCCESS;
618
+ });
617
619
}
618
620
619
621
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch (
0 commit comments