Skip to content

Commit 86f96f0

Browse files
author
Hugh Delaney
committed
Refactor unions to using std::variant
1 parent 614e6d0 commit 86f96f0

File tree

8 files changed

+416
-433
lines changed

8 files changed

+416
-433
lines changed

source/adapters/cuda/enqueue.cpp

Lines changed: 39 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -514,7 +514,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferReadRect(
514514
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
515515
ur_event_handle_t *phEvent) {
516516
ur_result_t Result = UR_RESULT_SUCCESS;
517-
CUdeviceptr DevPtr = hBuffer->Mem.BufferMem.get();
517+
CUdeviceptr DevPtr = std::get<BufferMem>(hBuffer->Mem).get();
518518
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
519519

520520
try {
@@ -562,7 +562,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWriteRect(
562562
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
563563
ur_event_handle_t *phEvent) {
564564
ur_result_t Result = UR_RESULT_SUCCESS;
565-
CUdeviceptr DevPtr = hBuffer->Mem.BufferMem.get();
565+
CUdeviceptr DevPtr = std::get<BufferMem>(hBuffer->Mem).get();
566566
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
567567

568568
try {
@@ -606,9 +606,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopy(
606606
ur_mem_handle_t hBufferDst, size_t srcOffset, size_t dstOffset, size_t size,
607607
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
608608
ur_event_handle_t *phEvent) {
609-
UR_ASSERT(size + dstOffset <= hBufferDst->Mem.BufferMem.getSize(),
609+
UR_ASSERT(size + dstOffset <= std::get<BufferMem>(hBufferDst->Mem).getSize(),
610610
UR_RESULT_ERROR_INVALID_SIZE);
611-
UR_ASSERT(size + srcOffset <= hBufferSrc->Mem.BufferMem.getSize(),
611+
UR_ASSERT(size + srcOffset <= std::get<BufferMem>(hBufferSrc->Mem).getSize(),
612612
UR_RESULT_ERROR_INVALID_SIZE);
613613

614614
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
@@ -628,8 +628,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopy(
628628
UR_CHECK_ERROR(RetImplEvent->start());
629629
}
630630

631-
auto Src = hBufferSrc->Mem.BufferMem.get() + srcOffset;
632-
auto Dst = hBufferDst->Mem.BufferMem.get() + dstOffset;
631+
auto Src = std::get<BufferMem>(hBufferSrc->Mem).get() + srcOffset;
632+
auto Dst = std::get<BufferMem>(hBufferDst->Mem).get() + dstOffset;
633633

634634
UR_CHECK_ERROR(cuMemcpyDtoDAsync(Dst, Src, size, Stream));
635635

@@ -654,8 +654,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopyRect(
654654
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
655655
ur_event_handle_t *phEvent) {
656656
ur_result_t Result = UR_RESULT_SUCCESS;
657-
CUdeviceptr SrcPtr = hBufferSrc->Mem.BufferMem.get();
658-
CUdeviceptr DstPtr = hBufferDst->Mem.BufferMem.get();
657+
CUdeviceptr SrcPtr = std::get<BufferMem>(hBufferSrc->Mem).get();
658+
CUdeviceptr DstPtr = std::get<BufferMem>(hBufferDst->Mem).get();
659659
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
660660

661661
try {
@@ -726,7 +726,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferFill(
726726
size_t patternSize, size_t offset, size_t size,
727727
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
728728
ur_event_handle_t *phEvent) {
729-
UR_ASSERT(size + offset <= hBuffer->Mem.BufferMem.getSize(),
729+
UR_ASSERT(size + offset <= std::get<BufferMem>(hBuffer->Mem).getSize(),
730730
UR_RESULT_ERROR_INVALID_SIZE);
731731

732732
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
@@ -745,7 +745,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferFill(
745745
UR_CHECK_ERROR(RetImplEvent->start());
746746
}
747747

748-
auto DstDevice = hBuffer->Mem.BufferMem.get() + offset;
748+
auto DstDevice = std::get<BufferMem>(hBuffer->Mem).get() + offset;
749749
auto N = size / patternSize;
750750

751751
// pattern size in bytes
@@ -892,7 +892,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageRead(
892892
Result = enqueueEventsWait(hQueue, CuStream, numEventsInWaitList,
893893
phEventWaitList);
894894

895-
CUarray Array = hImage->Mem.SurfaceMem.getArray();
895+
CUarray Array = std::get<SurfaceMem>(hImage->Mem).getArray();
896896

897897
CUDA_ARRAY_DESCRIPTOR ArrayDesc;
898898
UR_CHECK_ERROR(cuArrayGetDescriptor(&ArrayDesc, Array));
@@ -902,7 +902,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageRead(
902902
size_t ByteOffsetX = origin.x * ElementByteSize * ArrayDesc.NumChannels;
903903
size_t BytesToCopy = ElementByteSize * ArrayDesc.NumChannels * region.width;
904904

905-
ur_mem_type_t ImgType = hImage->Mem.SurfaceMem.getImageType();
905+
ur_mem_type_t ImgType = std::get<SurfaceMem>(hImage->Mem).getImageType();
906906

907907
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
908908
if (phEvent) {
@@ -964,7 +964,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageWrite(
964964
Result = enqueueEventsWait(hQueue, CuStream, numEventsInWaitList,
965965
phEventWaitList);
966966

967-
CUarray Array = hImage->Mem.SurfaceMem.getArray();
967+
CUarray Array = std::get<SurfaceMem>(hImage->Mem).getArray();
968968

969969
CUDA_ARRAY_DESCRIPTOR ArrayDesc;
970970
UR_CHECK_ERROR(cuArrayGetDescriptor(&ArrayDesc, Array));
@@ -982,7 +982,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageWrite(
982982
UR_CHECK_ERROR(RetImplEvent->start());
983983
}
984984

985-
ur_mem_type_t ImgType = hImage->Mem.SurfaceMem.getImageType();
985+
ur_mem_type_t ImgType = std::get<SurfaceMem>(hImage->Mem).getImageType();
986986
if (ImgType == UR_MEM_TYPE_IMAGE1D) {
987987
UR_CHECK_ERROR(
988988
cuMemcpyHtoAAsync(Array, ByteOffsetX, pSrc, BytesToCopy, CuStream));
@@ -1023,8 +1023,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageCopy(
10231023
UR_RESULT_ERROR_INVALID_MEM_OBJECT);
10241024
UR_ASSERT(hImageDst->MemType == ur_mem_handle_t_::Type::Surface,
10251025
UR_RESULT_ERROR_INVALID_MEM_OBJECT);
1026-
UR_ASSERT(hImageSrc->Mem.SurfaceMem.getImageType() ==
1027-
hImageDst->Mem.SurfaceMem.getImageType(),
1026+
UR_ASSERT(std::get<SurfaceMem>(hImageSrc->Mem).getImageType() ==
1027+
std::get<SurfaceMem>(hImageDst->Mem).getImageType(),
10281028
UR_RESULT_ERROR_INVALID_MEM_OBJECT);
10291029

10301030
ur_result_t Result = UR_RESULT_SUCCESS;
@@ -1035,8 +1035,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageCopy(
10351035
Result = enqueueEventsWait(hQueue, CuStream, numEventsInWaitList,
10361036
phEventWaitList);
10371037

1038-
CUarray SrcArray = hImageSrc->Mem.SurfaceMem.getArray();
1039-
CUarray DstArray = hImageDst->Mem.SurfaceMem.getArray();
1038+
CUarray SrcArray = std::get<SurfaceMem>(hImageSrc->Mem).getArray();
1039+
CUarray DstArray = std::get<SurfaceMem>(hImageDst->Mem).getArray();
10401040

10411041
CUDA_ARRAY_DESCRIPTOR SrcArrayDesc;
10421042
UR_CHECK_ERROR(cuArrayGetDescriptor(&SrcArrayDesc, SrcArray));
@@ -1065,7 +1065,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageCopy(
10651065
UR_CHECK_ERROR(RetImplEvent->start());
10661066
}
10671067

1068-
ur_mem_type_t ImgType = hImageSrc->Mem.SurfaceMem.getImageType();
1068+
ur_mem_type_t ImgType = std::get<SurfaceMem>(hImageSrc->Mem).getImageType();
10691069
if (ImgType == UR_MEM_TYPE_IMAGE1D) {
10701070
UR_CHECK_ERROR(cuMemcpyAtoA(DstArray, DstByteOffsetX, SrcArray,
10711071
SrcByteOffsetX, BytesToCopy));
@@ -1108,22 +1108,22 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferMap(
11081108
ur_event_handle_t *phEvent, void **ppRetMap) {
11091109
UR_ASSERT(hBuffer->MemType == ur_mem_handle_t_::Type::Buffer,
11101110
UR_RESULT_ERROR_INVALID_MEM_OBJECT);
1111-
UR_ASSERT(offset + size <= hBuffer->Mem.BufferMem.getSize(),
1111+
UR_ASSERT(offset + size <= std::get<BufferMem>(hBuffer->Mem).getSize(),
11121112
UR_RESULT_ERROR_INVALID_SIZE);
11131113

1114+
auto &BufferImpl = std::get<BufferMem>(hBuffer->Mem);
11141115
ur_result_t Result = UR_RESULT_ERROR_INVALID_MEM_OBJECT;
11151116
const bool IsPinned =
1116-
hBuffer->Mem.BufferMem.MemAllocMode ==
1117-
ur_mem_handle_t_::MemImpl::BufferMem::AllocMode::AllocHostPtr;
1117+
BufferImpl.MemAllocMode == BufferMem::AllocMode::AllocHostPtr;
11181118

11191119
// Currently no support for overlapping regions
1120-
if (hBuffer->Mem.BufferMem.getMapPtr() != nullptr) {
1120+
if (BufferImpl.getMapPtr() != nullptr) {
11211121
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
11221122
}
11231123

11241124
// Allocate a pointer in the host to store the mapped information
1125-
auto HostPtr = hBuffer->Mem.BufferMem.mapToPtr(size, offset, mapFlags);
1126-
*ppRetMap = hBuffer->Mem.BufferMem.getMapPtr();
1125+
auto HostPtr = BufferImpl.mapToPtr(size, offset, mapFlags);
1126+
*ppRetMap = BufferImpl.getMapPtr();
11271127
if (HostPtr) {
11281128
Result = UR_RESULT_SUCCESS;
11291129
}
@@ -1168,21 +1168,21 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemUnmap(
11681168
ur_result_t Result = UR_RESULT_SUCCESS;
11691169
UR_ASSERT(hMem->MemType == ur_mem_handle_t_::Type::Buffer,
11701170
UR_RESULT_ERROR_INVALID_MEM_OBJECT);
1171-
UR_ASSERT(hMem->Mem.BufferMem.getMapPtr() != nullptr,
1171+
UR_ASSERT(std::get<BufferMem>(hMem->Mem).getMapPtr() != nullptr,
11721172
UR_RESULT_ERROR_INVALID_MEM_OBJECT);
1173-
UR_ASSERT(hMem->Mem.BufferMem.getMapPtr() == pMappedPtr,
1173+
UR_ASSERT(std::get<BufferMem>(hMem->Mem).getMapPtr() == pMappedPtr,
11741174
UR_RESULT_ERROR_INVALID_MEM_OBJECT);
11751175

1176-
const bool IsPinned =
1177-
hMem->Mem.BufferMem.MemAllocMode ==
1178-
ur_mem_handle_t_::MemImpl::BufferMem::AllocMode::AllocHostPtr;
1176+
const bool IsPinned = std::get<BufferMem>(hMem->Mem).MemAllocMode ==
1177+
BufferMem::AllocMode::AllocHostPtr;
11791178

1180-
if (!IsPinned && (hMem->Mem.BufferMem.getMapFlags() & UR_MAP_FLAG_WRITE)) {
1179+
if (!IsPinned &&
1180+
(std::get<BufferMem>(hMem->Mem).getMapFlags() & UR_MAP_FLAG_WRITE)) {
11811181
// Pinned host memory is only on host so it doesn't need to be written to.
11821182
Result = urEnqueueMemBufferWrite(
1183-
hQueue, hMem, true, hMem->Mem.BufferMem.getMapOffset(),
1184-
hMem->Mem.BufferMem.getMapSize(), pMappedPtr, numEventsInWaitList,
1185-
phEventWaitList, phEvent);
1183+
hQueue, hMem, true, std::get<BufferMem>(hMem->Mem).getMapOffset(),
1184+
std::get<BufferMem>(hMem->Mem).getMapSize(), pMappedPtr,
1185+
numEventsInWaitList, phEventWaitList, phEvent);
11861186
} else {
11871187
ScopedContext Active(hQueue->getContext());
11881188

@@ -1203,7 +1203,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemUnmap(
12031203
}
12041204
}
12051205

1206-
hMem->Mem.BufferMem.unmap(pMappedPtr);
1206+
std::get<BufferMem>(hMem->Mem).unmap(pMappedPtr);
12071207
return Result;
12081208
}
12091209

@@ -1502,11 +1502,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferRead(
15021502
size_t offset, size_t size, void *pDst, uint32_t numEventsInWaitList,
15031503
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
15041504
UR_ASSERT(!hBuffer->isImage(), UR_RESULT_ERROR_INVALID_MEM_OBJECT);
1505-
UR_ASSERT(offset + size <= hBuffer->Mem.BufferMem.Size,
1505+
UR_ASSERT(offset + size <= std::get<BufferMem>(hBuffer->Mem).Size,
15061506
UR_RESULT_ERROR_INVALID_SIZE);
15071507

15081508
ur_result_t Result = UR_RESULT_SUCCESS;
1509-
CUdeviceptr DevPtr = hBuffer->Mem.BufferMem.get();
1509+
CUdeviceptr DevPtr = std::get<BufferMem>(hBuffer->Mem).get();
15101510
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
15111511

15121512
try {
@@ -1549,11 +1549,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWrite(
15491549
size_t offset, size_t size, const void *pSrc, uint32_t numEventsInWaitList,
15501550
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
15511551
UR_ASSERT(!hBuffer->isImage(), UR_RESULT_ERROR_INVALID_MEM_OBJECT);
1552-
UR_ASSERT(offset + size <= hBuffer->Mem.BufferMem.Size,
1552+
UR_ASSERT(offset + size <= std::get<BufferMem>(hBuffer->Mem).Size,
15531553
UR_RESULT_ERROR_INVALID_SIZE);
15541554

15551555
ur_result_t Result = UR_RESULT_SUCCESS;
1556-
CUdeviceptr DevPtr = hBuffer->Mem.BufferMem.get();
1556+
CUdeviceptr DevPtr = std::get<BufferMem>(hBuffer->Mem).get();
15571557
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
15581558

15591559
try {

source/adapters/cuda/kernel.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,7 @@ urKernelSetArgMemObj(ur_kernel_handle_t hKernel, uint32_t argIndex,
304304
if (hArgValue->MemType == ur_mem_handle_t_::Type::Surface) {
305305
CUDA_ARRAY3D_DESCRIPTOR arrayDesc;
306306
UR_CHECK_ERROR(cuArray3DGetDescriptor(
307-
&arrayDesc, hArgValue->Mem.SurfaceMem.getArray()));
307+
&arrayDesc, std::get<SurfaceMem>(hArgValue->Mem).getArray()));
308308
if (arrayDesc.Format != CU_AD_FORMAT_UNSIGNED_INT32 &&
309309
arrayDesc.Format != CU_AD_FORMAT_SIGNED_INT32 &&
310310
arrayDesc.Format != CU_AD_FORMAT_HALF &&
@@ -314,10 +314,10 @@ urKernelSetArgMemObj(ur_kernel_handle_t hKernel, uint32_t argIndex,
314314
UR_RESULT_ERROR_ADAPTER_SPECIFIC);
315315
return UR_RESULT_ERROR_ADAPTER_SPECIFIC;
316316
}
317-
CUsurfObject CuSurf = hArgValue->Mem.SurfaceMem.getSurface();
317+
CUsurfObject CuSurf = std::get<SurfaceMem>(hArgValue->Mem).getSurface();
318318
hKernel->setKernelArg(argIndex, sizeof(CuSurf), (void *)&CuSurf);
319319
} else {
320-
CUdeviceptr CuPtr = hArgValue->Mem.BufferMem.get();
320+
CUdeviceptr CuPtr = std::get<BufferMem>(hArgValue->Mem).get();
321321
hKernel->setKernelArg(argIndex, sizeof(CUdeviceptr), (void *)&CuPtr);
322322
}
323323
} catch (ur_result_t Err) {

source/adapters/cuda/memory.cpp

Lines changed: 29 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -44,22 +44,21 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemBufferCreate(
4444
CUdeviceptr Ptr = 0;
4545
auto HostPtr = pProperties ? pProperties->pHost : nullptr;
4646

47-
ur_mem_handle_t_::MemImpl::BufferMem::AllocMode AllocMode =
48-
ur_mem_handle_t_::MemImpl::BufferMem::AllocMode::Classic;
47+
BufferMem::AllocMode AllocMode = BufferMem::AllocMode::Classic;
4948

5049
if ((flags & UR_MEM_FLAG_USE_HOST_POINTER) && EnableUseHostPtr) {
5150
UR_CHECK_ERROR(
5251
cuMemHostRegister(HostPtr, size, CU_MEMHOSTREGISTER_DEVICEMAP));
5352
UR_CHECK_ERROR(cuMemHostGetDevicePointer(&Ptr, HostPtr, 0));
54-
AllocMode = ur_mem_handle_t_::MemImpl::BufferMem::AllocMode::UseHostPtr;
53+
AllocMode = BufferMem::AllocMode::UseHostPtr;
5554
} else if (flags & UR_MEM_FLAG_ALLOC_HOST_POINTER) {
5655
UR_CHECK_ERROR(cuMemAllocHost(&HostPtr, size));
5756
UR_CHECK_ERROR(cuMemHostGetDevicePointer(&Ptr, HostPtr, 0));
58-
AllocMode = ur_mem_handle_t_::MemImpl::BufferMem::AllocMode::AllocHostPtr;
57+
AllocMode = BufferMem::AllocMode::AllocHostPtr;
5958
} else {
6059
UR_CHECK_ERROR(cuMemAlloc(&Ptr, size));
6160
if (flags & UR_MEM_FLAG_ALLOC_COPY_HOST_POINTER) {
62-
AllocMode = ur_mem_handle_t_::MemImpl::BufferMem::AllocMode::CopyIn;
61+
AllocMode = BufferMem::AllocMode::CopyIn;
6362
}
6463
}
6564

@@ -121,21 +120,22 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemRelease(ur_mem_handle_t hMem) {
121120
ScopedContext Active(MemObjPtr->getContext());
122121

123122
if (hMem->MemType == ur_mem_handle_t_::Type::Buffer) {
124-
switch (MemObjPtr->Mem.BufferMem.MemAllocMode) {
125-
case ur_mem_handle_t_::MemImpl::BufferMem::AllocMode::CopyIn:
126-
case ur_mem_handle_t_::MemImpl::BufferMem::AllocMode::Classic:
127-
UR_CHECK_ERROR(cuMemFree(MemObjPtr->Mem.BufferMem.Ptr));
123+
auto &BufferImpl = std::get<BufferMem>(MemObjPtr->Mem);
124+
switch (BufferImpl.MemAllocMode) {
125+
case BufferMem::AllocMode::CopyIn:
126+
case BufferMem::AllocMode::Classic:
127+
UR_CHECK_ERROR(cuMemFree(BufferImpl.Ptr));
128128
break;
129-
case ur_mem_handle_t_::MemImpl::BufferMem::AllocMode::UseHostPtr:
130-
UR_CHECK_ERROR(cuMemHostUnregister(MemObjPtr->Mem.BufferMem.HostPtr));
129+
case BufferMem::AllocMode::UseHostPtr:
130+
UR_CHECK_ERROR(cuMemHostUnregister(BufferImpl.HostPtr));
131131
break;
132-
case ur_mem_handle_t_::MemImpl::BufferMem::AllocMode::AllocHostPtr:
133-
UR_CHECK_ERROR(cuMemFreeHost(MemObjPtr->Mem.BufferMem.HostPtr));
132+
case BufferMem::AllocMode::AllocHostPtr:
133+
UR_CHECK_ERROR(cuMemFreeHost(BufferImpl.HostPtr));
134134
};
135135
} else if (hMem->MemType == ur_mem_handle_t_::Type::Surface) {
136-
UR_CHECK_ERROR(
137-
cuSurfObjectDestroy(MemObjPtr->Mem.SurfaceMem.getSurface()));
138-
UR_CHECK_ERROR(cuArrayDestroy(MemObjPtr->Mem.SurfaceMem.getArray()));
136+
auto &SurfaceImpl = std::get<SurfaceMem>(MemObjPtr->Mem);
137+
UR_CHECK_ERROR(cuSurfObjectDestroy(SurfaceImpl.getSurface()));
138+
UR_CHECK_ERROR(cuArrayDestroy(SurfaceImpl.getArray()));
139139
}
140140

141141
} catch (ur_result_t Err) {
@@ -163,8 +163,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemRelease(ur_mem_handle_t hMem) {
163163
/// \return UR_RESULT_SUCCESS
164164
UR_APIEXPORT ur_result_t UR_APICALL
165165
urMemGetNativeHandle(ur_mem_handle_t hMem, ur_native_handle_t *phNativeMem) {
166-
*phNativeMem =
167-
reinterpret_cast<ur_native_handle_t>(hMem->Mem.BufferMem.get());
166+
*phNativeMem = reinterpret_cast<ur_native_handle_t>(
167+
std::get<BufferMem>(hMem->Mem).get());
168168
return UR_RESULT_SUCCESS;
169169
}
170170

@@ -183,8 +183,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemGetInfo(ur_mem_handle_t hMemory,
183183
case UR_MEM_INFO_SIZE: {
184184
try {
185185
size_t AllocSize = 0;
186-
UR_CHECK_ERROR(cuMemGetAddressRange(nullptr, &AllocSize,
187-
hMemory->Mem.BufferMem.Ptr));
186+
UR_CHECK_ERROR(cuMemGetAddressRange(
187+
nullptr, &AllocSize, std::get<BufferMem>(hMemory->Mem).Ptr));
188188
return ReturnValue(AllocSize);
189189
} catch (ur_result_t Err) {
190190
return Err;
@@ -443,25 +443,22 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemBufferPartition(
443443

444444
UR_ASSERT(pRegion->size != 0u, UR_RESULT_ERROR_INVALID_BUFFER_SIZE);
445445

446+
auto &BufferImpl = std::get<BufferMem>(hBuffer->Mem);
447+
446448
assert((pRegion->origin <= (pRegion->origin + pRegion->size)) && "Overflow");
447-
UR_ASSERT(
448-
((pRegion->origin + pRegion->size) <= hBuffer->Mem.BufferMem.getSize()),
449-
UR_RESULT_ERROR_INVALID_BUFFER_SIZE);
449+
UR_ASSERT(((pRegion->origin + pRegion->size) <= BufferImpl.getSize()),
450+
UR_RESULT_ERROR_INVALID_BUFFER_SIZE);
450451
// Retained indirectly due to retaining parent buffer below.
451452
ur_context_handle_t Context = hBuffer->Context;
452453

453-
ur_mem_handle_t_::MemImpl::BufferMem::AllocMode AllocMode =
454-
ur_mem_handle_t_::MemImpl::BufferMem::AllocMode::Classic;
454+
BufferMem::AllocMode AllocMode = BufferMem::AllocMode::Classic;
455455

456-
assert(hBuffer->Mem.BufferMem.Ptr !=
457-
ur_mem_handle_t_::MemImpl::BufferMem::native_type{0});
458-
ur_mem_handle_t_::MemImpl::BufferMem::native_type Ptr =
459-
hBuffer->Mem.BufferMem.Ptr + pRegion->origin;
456+
assert(BufferImpl.Ptr != BufferMem::native_type{0});
457+
BufferMem::native_type Ptr = BufferImpl.Ptr + pRegion->origin;
460458

461459
void *HostPtr = nullptr;
462-
if (hBuffer->Mem.BufferMem.HostPtr) {
463-
HostPtr =
464-
static_cast<char *>(hBuffer->Mem.BufferMem.HostPtr) + pRegion->origin;
460+
if (BufferImpl.HostPtr) {
461+
HostPtr = static_cast<char *>(BufferImpl.HostPtr) + pRegion->origin;
465462
}
466463

467464
std::unique_ptr<ur_mem_handle_t_> MemObj{nullptr};

0 commit comments

Comments
 (0)