@@ -514,7 +514,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferReadRect(
514
514
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
515
515
ur_event_handle_t *phEvent) {
516
516
ur_result_t Result = UR_RESULT_SUCCESS;
517
- CUdeviceptr DevPtr = hBuffer->Mem . BufferMem .get ();
517
+ CUdeviceptr DevPtr = std::get<BufferMem>( hBuffer->Mem ) .get ();
518
518
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr };
519
519
520
520
try {
@@ -562,7 +562,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWriteRect(
562
562
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
563
563
ur_event_handle_t *phEvent) {
564
564
ur_result_t Result = UR_RESULT_SUCCESS;
565
- CUdeviceptr DevPtr = hBuffer->Mem . BufferMem .get ();
565
+ CUdeviceptr DevPtr = std::get<BufferMem>( hBuffer->Mem ) .get ();
566
566
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr };
567
567
568
568
try {
@@ -606,9 +606,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopy(
606
606
ur_mem_handle_t hBufferDst, size_t srcOffset, size_t dstOffset, size_t size,
607
607
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
608
608
ur_event_handle_t *phEvent) {
609
- UR_ASSERT (size + dstOffset <= hBufferDst->Mem . BufferMem .getSize (),
609
+ UR_ASSERT (size + dstOffset <= std::get<BufferMem>( hBufferDst->Mem ) .getSize (),
610
610
UR_RESULT_ERROR_INVALID_SIZE);
611
- UR_ASSERT (size + srcOffset <= hBufferSrc->Mem . BufferMem .getSize (),
611
+ UR_ASSERT (size + srcOffset <= std::get<BufferMem>( hBufferSrc->Mem ) .getSize (),
612
612
UR_RESULT_ERROR_INVALID_SIZE);
613
613
614
614
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr };
@@ -628,8 +628,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopy(
628
628
UR_CHECK_ERROR (RetImplEvent->start ());
629
629
}
630
630
631
- auto Src = hBufferSrc->Mem . BufferMem .get () + srcOffset;
632
- auto Dst = hBufferDst->Mem . BufferMem .get () + dstOffset;
631
+ auto Src = std::get<BufferMem>( hBufferSrc->Mem ) .get () + srcOffset;
632
+ auto Dst = std::get<BufferMem>( hBufferDst->Mem ) .get () + dstOffset;
633
633
634
634
UR_CHECK_ERROR (cuMemcpyDtoDAsync (Dst, Src, size, Stream));
635
635
@@ -654,8 +654,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopyRect(
654
654
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
655
655
ur_event_handle_t *phEvent) {
656
656
ur_result_t Result = UR_RESULT_SUCCESS;
657
- CUdeviceptr SrcPtr = hBufferSrc->Mem . BufferMem .get ();
658
- CUdeviceptr DstPtr = hBufferDst->Mem . BufferMem .get ();
657
+ CUdeviceptr SrcPtr = std::get<BufferMem>( hBufferSrc->Mem ) .get ();
658
+ CUdeviceptr DstPtr = std::get<BufferMem>( hBufferDst->Mem ) .get ();
659
659
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr };
660
660
661
661
try {
@@ -726,7 +726,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferFill(
726
726
size_t patternSize, size_t offset, size_t size,
727
727
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
728
728
ur_event_handle_t *phEvent) {
729
- UR_ASSERT (size + offset <= hBuffer->Mem . BufferMem .getSize (),
729
+ UR_ASSERT (size + offset <= std::get<BufferMem>( hBuffer->Mem ) .getSize (),
730
730
UR_RESULT_ERROR_INVALID_SIZE);
731
731
732
732
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr };
@@ -745,7 +745,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferFill(
745
745
UR_CHECK_ERROR (RetImplEvent->start ());
746
746
}
747
747
748
- auto DstDevice = hBuffer->Mem . BufferMem .get () + offset;
748
+ auto DstDevice = std::get<BufferMem>( hBuffer->Mem ) .get () + offset;
749
749
auto N = size / patternSize;
750
750
751
751
// pattern size in bytes
@@ -892,7 +892,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageRead(
892
892
Result = enqueueEventsWait (hQueue, CuStream, numEventsInWaitList,
893
893
phEventWaitList);
894
894
895
- CUarray Array = hImage->Mem . SurfaceMem .getArray ();
895
+ CUarray Array = std::get<SurfaceMem>( hImage->Mem ) .getArray ();
896
896
897
897
CUDA_ARRAY_DESCRIPTOR ArrayDesc;
898
898
UR_CHECK_ERROR (cuArrayGetDescriptor (&ArrayDesc, Array));
@@ -902,7 +902,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageRead(
902
902
size_t ByteOffsetX = origin.x * ElementByteSize * ArrayDesc.NumChannels ;
903
903
size_t BytesToCopy = ElementByteSize * ArrayDesc.NumChannels * region.width ;
904
904
905
- ur_mem_type_t ImgType = hImage->Mem . SurfaceMem .getImageType ();
905
+ ur_mem_type_t ImgType = std::get<SurfaceMem>( hImage->Mem ) .getImageType ();
906
906
907
907
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr };
908
908
if (phEvent) {
@@ -964,7 +964,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageWrite(
964
964
Result = enqueueEventsWait (hQueue, CuStream, numEventsInWaitList,
965
965
phEventWaitList);
966
966
967
- CUarray Array = hImage->Mem . SurfaceMem .getArray ();
967
+ CUarray Array = std::get<SurfaceMem>( hImage->Mem ) .getArray ();
968
968
969
969
CUDA_ARRAY_DESCRIPTOR ArrayDesc;
970
970
UR_CHECK_ERROR (cuArrayGetDescriptor (&ArrayDesc, Array));
@@ -982,7 +982,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageWrite(
982
982
UR_CHECK_ERROR (RetImplEvent->start ());
983
983
}
984
984
985
- ur_mem_type_t ImgType = hImage->Mem . SurfaceMem .getImageType ();
985
+ ur_mem_type_t ImgType = std::get<SurfaceMem>( hImage->Mem ) .getImageType ();
986
986
if (ImgType == UR_MEM_TYPE_IMAGE1D) {
987
987
UR_CHECK_ERROR (
988
988
cuMemcpyHtoAAsync (Array, ByteOffsetX, pSrc, BytesToCopy, CuStream));
@@ -1023,8 +1023,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageCopy(
1023
1023
UR_RESULT_ERROR_INVALID_MEM_OBJECT);
1024
1024
UR_ASSERT (hImageDst->MemType == ur_mem_handle_t_::Type::Surface,
1025
1025
UR_RESULT_ERROR_INVALID_MEM_OBJECT);
1026
- UR_ASSERT (hImageSrc->Mem . SurfaceMem .getImageType () ==
1027
- hImageDst->Mem . SurfaceMem .getImageType (),
1026
+ UR_ASSERT (std::get<SurfaceMem>( hImageSrc->Mem ) .getImageType () ==
1027
+ std::get<SurfaceMem>( hImageDst->Mem ) .getImageType (),
1028
1028
UR_RESULT_ERROR_INVALID_MEM_OBJECT);
1029
1029
1030
1030
ur_result_t Result = UR_RESULT_SUCCESS;
@@ -1035,8 +1035,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageCopy(
1035
1035
Result = enqueueEventsWait (hQueue, CuStream, numEventsInWaitList,
1036
1036
phEventWaitList);
1037
1037
1038
- CUarray SrcArray = hImageSrc->Mem . SurfaceMem .getArray ();
1039
- CUarray DstArray = hImageDst->Mem . SurfaceMem .getArray ();
1038
+ CUarray SrcArray = std::get<SurfaceMem>( hImageSrc->Mem ) .getArray ();
1039
+ CUarray DstArray = std::get<SurfaceMem>( hImageDst->Mem ) .getArray ();
1040
1040
1041
1041
CUDA_ARRAY_DESCRIPTOR SrcArrayDesc;
1042
1042
UR_CHECK_ERROR (cuArrayGetDescriptor (&SrcArrayDesc, SrcArray));
@@ -1065,7 +1065,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageCopy(
1065
1065
UR_CHECK_ERROR (RetImplEvent->start ());
1066
1066
}
1067
1067
1068
- ur_mem_type_t ImgType = hImageSrc->Mem . SurfaceMem .getImageType ();
1068
+ ur_mem_type_t ImgType = std::get<SurfaceMem>( hImageSrc->Mem ) .getImageType ();
1069
1069
if (ImgType == UR_MEM_TYPE_IMAGE1D) {
1070
1070
UR_CHECK_ERROR (cuMemcpyAtoA (DstArray, DstByteOffsetX, SrcArray,
1071
1071
SrcByteOffsetX, BytesToCopy));
@@ -1108,22 +1108,22 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferMap(
1108
1108
ur_event_handle_t *phEvent, void **ppRetMap) {
1109
1109
UR_ASSERT (hBuffer->MemType == ur_mem_handle_t_::Type::Buffer,
1110
1110
UR_RESULT_ERROR_INVALID_MEM_OBJECT);
1111
- UR_ASSERT (offset + size <= hBuffer->Mem . BufferMem .getSize (),
1111
+ UR_ASSERT (offset + size <= std::get<BufferMem>( hBuffer->Mem ) .getSize (),
1112
1112
UR_RESULT_ERROR_INVALID_SIZE);
1113
1113
1114
+ auto &BufferImpl = std::get<BufferMem>(hBuffer->Mem );
1114
1115
ur_result_t Result = UR_RESULT_ERROR_INVALID_MEM_OBJECT;
1115
1116
const bool IsPinned =
1116
- hBuffer->Mem .BufferMem .MemAllocMode ==
1117
- ur_mem_handle_t_::MemImpl::BufferMem::AllocMode::AllocHostPtr;
1117
+ BufferImpl.MemAllocMode == BufferMem::AllocMode::AllocHostPtr;
1118
1118
1119
1119
// Currently no support for overlapping regions: a buffer that already has a
// mapped pointer is rejected outright
1120
- if (hBuffer-> Mem . BufferMem .getMapPtr () != nullptr ) {
1120
+ if (BufferImpl .getMapPtr () != nullptr ) {
1121
1121
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
1122
1122
}
1123
1123
1124
1124
// Allocate a pointer in the host to store the mapped information
1125
- auto HostPtr = hBuffer-> Mem . BufferMem .mapToPtr (size, offset, mapFlags);
1126
- *ppRetMap = hBuffer-> Mem . BufferMem .getMapPtr ();
1125
+ auto HostPtr = BufferImpl .mapToPtr (size, offset, mapFlags);
1126
+ *ppRetMap = BufferImpl .getMapPtr ();
1127
1127
if (HostPtr) {
1128
1128
Result = UR_RESULT_SUCCESS;
1129
1129
}
@@ -1168,21 +1168,21 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemUnmap(
1168
1168
ur_result_t Result = UR_RESULT_SUCCESS;
1169
1169
UR_ASSERT (hMem->MemType == ur_mem_handle_t_::Type::Buffer,
1170
1170
UR_RESULT_ERROR_INVALID_MEM_OBJECT);
1171
- UR_ASSERT (hMem->Mem . BufferMem .getMapPtr () != nullptr ,
1171
+ UR_ASSERT (std::get<BufferMem>( hMem->Mem ) .getMapPtr () != nullptr ,
1172
1172
UR_RESULT_ERROR_INVALID_MEM_OBJECT);
1173
- UR_ASSERT (hMem->Mem . BufferMem .getMapPtr () == pMappedPtr,
1173
+ UR_ASSERT (std::get<BufferMem>( hMem->Mem ) .getMapPtr () == pMappedPtr,
1174
1174
UR_RESULT_ERROR_INVALID_MEM_OBJECT);
1175
1175
1176
- const bool IsPinned =
1177
- hMem->Mem .BufferMem .MemAllocMode ==
1178
- ur_mem_handle_t_::MemImpl::BufferMem::AllocMode::AllocHostPtr;
1176
+ const bool IsPinned = std::get<BufferMem>(hMem->Mem ).MemAllocMode ==
1177
+ BufferMem::AllocMode::AllocHostPtr;
1179
1178
1180
- if (!IsPinned && (hMem->Mem .BufferMem .getMapFlags () & UR_MAP_FLAG_WRITE)) {
1179
+ if (!IsPinned &&
1180
+ (std::get<BufferMem>(hMem->Mem ).getMapFlags () & UR_MAP_FLAG_WRITE)) {
1181
1181
// Pinned host memory lives only on the host, so it needs no write-back; only
// non-pinned buffers mapped with UR_MAP_FLAG_WRITE are written to the buffer here.
1182
1182
Result = urEnqueueMemBufferWrite (
1183
- hQueue, hMem, true , hMem->Mem . BufferMem .getMapOffset (),
1184
- hMem->Mem . BufferMem . getMapSize (), pMappedPtr, numEventsInWaitList ,
1185
- phEventWaitList, phEvent);
1183
+ hQueue, hMem, true , std::get<BufferMem>( hMem->Mem ) .getMapOffset (),
1184
+ std::get<BufferMem>( hMem->Mem ). getMapSize (), pMappedPtr,
1185
+ numEventsInWaitList, phEventWaitList, phEvent);
1186
1186
} else {
1187
1187
ScopedContext Active (hQueue->getContext ());
1188
1188
@@ -1203,7 +1203,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemUnmap(
1203
1203
}
1204
1204
}
1205
1205
1206
- hMem->Mem . BufferMem .unmap (pMappedPtr);
1206
+ std::get<BufferMem>( hMem->Mem ) .unmap (pMappedPtr);
1207
1207
return Result;
1208
1208
}
1209
1209
@@ -1502,11 +1502,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferRead(
1502
1502
size_t offset, size_t size, void *pDst, uint32_t numEventsInWaitList,
1503
1503
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
1504
1504
UR_ASSERT (!hBuffer->isImage (), UR_RESULT_ERROR_INVALID_MEM_OBJECT);
1505
- UR_ASSERT (offset + size <= hBuffer->Mem . BufferMem .Size ,
1505
+ UR_ASSERT (offset + size <= std::get<BufferMem>( hBuffer->Mem ) .Size ,
1506
1506
UR_RESULT_ERROR_INVALID_SIZE);
1507
1507
1508
1508
ur_result_t Result = UR_RESULT_SUCCESS;
1509
- CUdeviceptr DevPtr = hBuffer->Mem . BufferMem .get ();
1509
+ CUdeviceptr DevPtr = std::get<BufferMem>( hBuffer->Mem ) .get ();
1510
1510
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr };
1511
1511
1512
1512
try {
@@ -1549,11 +1549,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWrite(
1549
1549
size_t offset, size_t size, const void *pSrc, uint32_t numEventsInWaitList,
1550
1550
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
1551
1551
UR_ASSERT (!hBuffer->isImage (), UR_RESULT_ERROR_INVALID_MEM_OBJECT);
1552
- UR_ASSERT (offset + size <= hBuffer->Mem . BufferMem .Size ,
1552
+ UR_ASSERT (offset + size <= std::get<BufferMem>( hBuffer->Mem ) .Size ,
1553
1553
UR_RESULT_ERROR_INVALID_SIZE);
1554
1554
1555
1555
ur_result_t Result = UR_RESULT_SUCCESS;
1556
- CUdeviceptr DevPtr = hBuffer->Mem . BufferMem .get ();
1556
+ CUdeviceptr DevPtr = std::get<BufferMem>( hBuffer->Mem ) .get ();
1557
1557
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr };
1558
1558
1559
1559
try {
0 commit comments