12
12
13
13
#include " common.hpp"
14
14
#include " context.hpp"
15
+ #include " enqueue.hpp"
15
16
#include " memory.hpp"
16
17
17
18
/// Creates a UR Memory object using a CUDA memory allocation.
@@ -238,7 +239,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemImageCreate(
238
239
try {
239
240
if (PerformInitialCopy) {
240
241
for (const auto &Device : hContext->getDevices ()) {
241
- UR_CHECK_ERROR (migrateMemoryToDeviceIfNeeded (URMemObj.get (), Device));
242
+ // Synchronous behaviour is best in this case
243
+ ScopedContext Active (Device);
244
+ CUstream Stream{0 }; // Use default stream
245
+ UR_CHECK_ERROR (enqueueMigrateMemoryToDeviceIfNeeded (URMemObj.get (),
246
+ Device, Stream));
247
+ UR_CHECK_ERROR (cuStreamSynchronize (Stream));
242
248
}
243
249
}
244
250
@@ -496,27 +502,28 @@ ur_result_t allocateMemObjOnDeviceIfNeeded(ur_mem_handle_t Mem,
496
502
}
497
503
498
504
namespace {
499
- ur_result_t migrateBufferToDevice (ur_mem_handle_t Mem,
500
- ur_device_handle_t hDevice) {
505
+ ur_result_t enqueueMigrateBufferToDevice (ur_mem_handle_t Mem,
506
+ ur_device_handle_t hDevice,
507
+ CUstream Stream) {
501
508
auto &Buffer = std::get<BufferMem>(Mem->Mem );
502
- if (Mem->LastEventWritingToMemObj == nullptr ) {
509
+ if (Mem->LastQueueWritingToMemObj == nullptr ) {
503
510
// Device allocation being initialized from host for the first time
504
511
if (Buffer.HostPtr ) {
505
- UR_CHECK_ERROR (
506
- cuMemcpyHtoD (Buffer. getPtr (hDevice), Buffer. HostPtr , Buffer.Size ));
512
+ UR_CHECK_ERROR (cuMemcpyHtoDAsync (Buffer. getPtr (hDevice), Buffer. HostPtr ,
513
+ Buffer.Size , Stream ));
507
514
}
508
- } else if (Mem->LastEventWritingToMemObj ->getQueue ()->getDevice () !=
509
- hDevice) {
510
- UR_CHECK_ERROR (cuMemcpyDtoD (
515
+ } else if (Mem->LastQueueWritingToMemObj ->getDevice () != hDevice) {
516
+ UR_CHECK_ERROR (cuMemcpyDtoDAsync (
511
517
Buffer.getPtr (hDevice),
512
- Buffer.getPtr (Mem->LastEventWritingToMemObj -> getQueue ()-> getDevice ()),
513
- Buffer. Size ));
518
+ Buffer.getPtr (Mem->LastQueueWritingToMemObj -> getDevice ()), Buffer. Size ,
519
+ Stream ));
514
520
}
515
521
return UR_RESULT_SUCCESS;
516
522
}
517
523
518
- ur_result_t migrateImageToDevice (ur_mem_handle_t Mem,
519
- ur_device_handle_t hDevice) {
524
+ ur_result_t enqueueMigrateImageToDevice (ur_mem_handle_t Mem,
525
+ ur_device_handle_t hDevice,
526
+ CUstream Stream) {
520
527
auto &Image = std::get<SurfaceMem>(Mem->Mem );
521
528
// When a dimension isn't used image_desc has the size set to 1
522
529
size_t PixelSizeBytes = Image.PixelTypeSizeBytes *
@@ -547,40 +554,42 @@ ur_result_t migrateImageToDevice(ur_mem_handle_t Mem,
547
554
CpyDesc3D.Depth = Image.ImageDesc .depth ;
548
555
}
549
556
550
- if (Mem->LastEventWritingToMemObj == nullptr ) {
557
+ if (Mem->LastQueueWritingToMemObj == nullptr ) {
551
558
if (Image.HostPtr ) {
552
559
if (Image.ImageDesc .type == UR_MEM_TYPE_IMAGE1D) {
553
- UR_CHECK_ERROR (
554
- cuMemcpyHtoA (ImageArray, 0 , Image. HostPtr , ImageSizeBytes));
560
+ UR_CHECK_ERROR (cuMemcpyHtoAAsync (ImageArray, 0 , Image. HostPtr ,
561
+ ImageSizeBytes, Stream ));
555
562
} else if (Image.ImageDesc .type == UR_MEM_TYPE_IMAGE2D) {
556
563
CpyDesc2D.srcMemoryType = CUmemorytype_enum::CU_MEMORYTYPE_HOST;
557
564
CpyDesc2D.srcHost = Image.HostPtr ;
558
- UR_CHECK_ERROR (cuMemcpy2D (&CpyDesc2D));
565
+ UR_CHECK_ERROR (cuMemcpy2DAsync (&CpyDesc2D, Stream ));
559
566
} else if (Image.ImageDesc .type == UR_MEM_TYPE_IMAGE3D) {
560
567
CpyDesc3D.srcMemoryType = CUmemorytype_enum::CU_MEMORYTYPE_HOST;
561
568
CpyDesc3D.srcHost = Image.HostPtr ;
562
- UR_CHECK_ERROR (cuMemcpy3D (&CpyDesc3D));
569
+ UR_CHECK_ERROR (cuMemcpy3DAsync (&CpyDesc3D, Stream ));
563
570
}
564
571
}
565
- } else if (Mem->LastEventWritingToMemObj ->getQueue ()->getDevice () !=
566
- hDevice) {
572
+ } else if (Mem->LastQueueWritingToMemObj ->getDevice () != hDevice) {
567
573
if (Image.ImageDesc .type == UR_MEM_TYPE_IMAGE1D) {
574
+ // Blocking wait needed
575
+ UR_CHECK_ERROR (urQueueFinish (Mem->LastQueueWritingToMemObj ));
568
576
// FIXME: 1D memcpy from DtoD going through the host.
569
577
UR_CHECK_ERROR (cuMemcpyAtoH (
570
578
Image.HostPtr ,
571
- Image.getArray (
572
- Mem->LastEventWritingToMemObj ->getQueue ()->getDevice ()),
579
+ Image.getArray (Mem->LastQueueWritingToMemObj ->getDevice ()),
573
580
0 /* srcOffset*/ , ImageSizeBytes));
574
581
UR_CHECK_ERROR (
575
582
cuMemcpyHtoA (ImageArray, 0 , Image.HostPtr , ImageSizeBytes));
576
583
} else if (Image.ImageDesc .type == UR_MEM_TYPE_IMAGE2D) {
577
- CpyDesc2D.srcArray = Image.getArray (
578
- Mem->LastEventWritingToMemObj ->getQueue ()->getDevice ());
579
- UR_CHECK_ERROR (cuMemcpy2D (&CpyDesc2D));
584
+ CpyDesc2D.srcMemoryType = CUmemorytype_enum::CU_MEMORYTYPE_DEVICE;
585
+ CpyDesc2D.srcArray =
586
+ Image.getArray (Mem->LastQueueWritingToMemObj ->getDevice ());
587
+ UR_CHECK_ERROR (cuMemcpy2DAsync (&CpyDesc2D, Stream));
580
588
} else if (Image.ImageDesc .type == UR_MEM_TYPE_IMAGE3D) {
581
- CpyDesc3D.srcArray = Image.getArray (
582
- Mem->LastEventWritingToMemObj ->getQueue ()->getDevice ());
583
- UR_CHECK_ERROR (cuMemcpy3D (&CpyDesc3D));
589
+ CpyDesc3D.srcMemoryType = CUmemorytype_enum::CU_MEMORYTYPE_DEVICE;
590
+ CpyDesc3D.srcArray =
591
+ Image.getArray (Mem->LastQueueWritingToMemObj ->getDevice ());
592
+ UR_CHECK_ERROR (cuMemcpy3DAsync (&CpyDesc3D, Stream));
584
593
}
585
594
}
586
595
return UR_RESULT_SUCCESS;
@@ -589,8 +598,8 @@ ur_result_t migrateImageToDevice(ur_mem_handle_t Mem,
589
598
590
599
// If calling this entry point it is necessary to lock the memoryMigrationMutex
591
600
// beforehand
592
- ur_result_t migrateMemoryToDeviceIfNeeded ( ur_mem_handle_t Mem,
593
- const ur_device_handle_t hDevice) {
601
+ ur_result_t enqueueMigrateMemoryToDeviceIfNeeded (
602
+ ur_mem_handle_t Mem, const ur_device_handle_t hDevice, CUstream Stream ) {
594
603
UR_ASSERT (hDevice, UR_RESULT_ERROR_INVALID_NULL_HANDLE);
595
604
// Device allocation has already been initialized with most up to date
596
605
// data in buffer
@@ -601,9 +610,9 @@ ur_result_t migrateMemoryToDeviceIfNeeded(ur_mem_handle_t Mem,
601
610
602
611
ScopedContext Active (hDevice);
603
612
if (Mem->isBuffer ()) {
604
- UR_CHECK_ERROR (migrateBufferToDevice (Mem, hDevice));
613
+ UR_CHECK_ERROR (enqueueMigrateBufferToDevice (Mem, hDevice, Stream ));
605
614
} else {
606
- UR_CHECK_ERROR (migrateImageToDevice (Mem, hDevice));
615
+ UR_CHECK_ERROR (enqueueMigrateImageToDevice (Mem, hDevice, Stream ));
607
616
}
608
617
609
618
Mem->HaveMigratedToDeviceSinceLastWrite [Mem->getContext ()->getDeviceIndex (
0 commit comments