Skip to content

Commit 6f9292e

Browse files
authored
[UR][CUDA] Cleanup UMF pools creation code (#18361)
This code was roughly duplicated between the context and the platform, and could also be cleaned up and refactored to be clearer and not require as many helper functions and macros. It is going to create more "MemoryProviderParams" during platform initialization. But that cost of that is likely pretty small and not on the hot path, so it seems like a fair tradeoff for cleaner code.
1 parent 300bba0 commit 6f9292e

File tree

4 files changed

+49
-121
lines changed

4 files changed

+49
-121
lines changed

unified-runtime/source/adapters/cuda/common.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
#include "common.hpp"
1212
#include "logger/ur_logger.hpp"
1313

14+
#include "umf_helpers.hpp"
15+
1416
#include <cuda.h>
1517
#include <nvml.h>
1618

@@ -169,4 +171,33 @@ ur_result_t getProviderNativeError(const char *providerName, int32_t error) {
169171

170172
return UR_RESULT_ERROR_UNKNOWN;
171173
}
174+
175+
ur_result_t CreateProviderPool(int cuDevice, void *cuContext,
176+
umf_usm_memory_type_t type,
177+
umf_memory_provider_handle_t *provider,
178+
umf_memory_pool_handle_t *pool) {
179+
umf_cuda_memory_provider_params_handle_t CUMemoryProviderParams = nullptr;
180+
UMF_RETURN_UR_ERROR(
181+
umfCUDAMemoryProviderParamsCreate(&CUMemoryProviderParams));
182+
OnScopeExit Cleanup(
183+
[=]() { umfCUDAMemoryProviderParamsDestroy(CUMemoryProviderParams); });
184+
185+
// Setup memory provider parameters
186+
UMF_RETURN_UR_ERROR(
187+
umfCUDAMemoryProviderParamsSetContext(CUMemoryProviderParams, cuContext));
188+
UMF_RETURN_UR_ERROR(
189+
umfCUDAMemoryProviderParamsSetDevice(CUMemoryProviderParams, cuDevice));
190+
UMF_RETURN_UR_ERROR(
191+
umfCUDAMemoryProviderParamsSetMemoryType(CUMemoryProviderParams, type));
192+
193+
// Create memory provider
194+
UMF_RETURN_UR_ERROR(umfMemoryProviderCreate(
195+
umfCUDAMemoryProviderOps(), CUMemoryProviderParams, provider));
196+
197+
// Create memory pool
198+
UMF_RETURN_UR_ERROR(
199+
umfPoolCreate(umfProxyPoolOps(), *provider, nullptr, 0, pool));
200+
201+
return UR_RESULT_SUCCESS;
202+
}
172203
} // namespace umf

unified-runtime/source/adapters/cuda/common.hpp

Lines changed: 6 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,10 @@
1414
#include <ur/ur.hpp>
1515

1616
#include <umf/base.h>
17+
#include <umf/memory_pool.h>
18+
#include <umf/memory_provider.h>
1719
#include <umf/providers/provider_cuda.h>
1820

19-
#define UMF_RETURN_UMF_ERROR(UmfResult) \
20-
do { \
21-
umf_result_t UmfResult_ = (UmfResult); \
22-
if (UmfResult_ != UMF_RESULT_SUCCESS) { \
23-
return UmfResult_; \
24-
} \
25-
} while (0)
26-
2721
ur_result_t mapErrorUR(CUresult Result);
2822

2923
/// Converts CUDA error into UR error codes, and outputs error information
@@ -57,24 +51,8 @@ extern thread_local char ErrorMessage[MaxMessageSize];
5751
void setPluginSpecificMessage(CUresult cu_res);
5852

5953
namespace umf {
60-
61-
inline umf_result_t setCUMemoryProviderParams(
62-
umf_cuda_memory_provider_params_handle_t CUMemoryProviderParams,
63-
int cuDevice, void *cuContext, umf_usm_memory_type_t memType) {
64-
65-
umf_result_t UmfResult =
66-
umfCUDAMemoryProviderParamsSetContext(CUMemoryProviderParams, cuContext);
67-
UMF_RETURN_UMF_ERROR(UmfResult);
68-
69-
UmfResult =
70-
umfCUDAMemoryProviderParamsSetDevice(CUMemoryProviderParams, cuDevice);
71-
UMF_RETURN_UMF_ERROR(UmfResult);
72-
73-
UmfResult =
74-
umfCUDAMemoryProviderParamsSetMemoryType(CUMemoryProviderParams, memType);
75-
UMF_RETURN_UMF_ERROR(UmfResult);
76-
77-
return UMF_RESULT_SUCCESS;
78-
}
79-
54+
ur_result_t CreateProviderPool(int cuDevice, void *cuContext,
55+
umf_usm_memory_type_t type,
56+
umf_memory_provider_handle_t *provider,
57+
umf_memory_pool_handle_t *pool);
8058
} // namespace umf

unified-runtime/source/adapters/cuda/context.hpp

Lines changed: 3 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -78,38 +78,6 @@ typedef void (*ur_context_extended_deleter_t)(void *user_data);
7878
///
7979
///
8080

81-
static ur_result_t
82-
CreateHostMemoryProviderPool(ur_device_handle_t_ *DeviceHandle,
83-
umf_memory_provider_handle_t *MemoryProviderHost,
84-
umf_memory_pool_handle_t *MemoryPoolHost) {
85-
86-
*MemoryProviderHost = nullptr;
87-
CUcontext context = DeviceHandle->getNativeContext();
88-
89-
umf_cuda_memory_provider_params_handle_t CUMemoryProviderParams = nullptr;
90-
umf_result_t UmfResult =
91-
umfCUDAMemoryProviderParamsCreate(&CUMemoryProviderParams);
92-
UMF_RETURN_UR_ERROR(UmfResult);
93-
OnScopeExit Cleanup(
94-
[=]() { umfCUDAMemoryProviderParamsDestroy(CUMemoryProviderParams); });
95-
96-
UmfResult = umf::setCUMemoryProviderParams(
97-
CUMemoryProviderParams, 0 /* cuDevice */, context, UMF_MEMORY_TYPE_HOST);
98-
UMF_RETURN_UR_ERROR(UmfResult);
99-
100-
// create UMF CUDA memory provider and pool for the host memory
101-
// (UMF_MEMORY_TYPE_HOST)
102-
UmfResult = umfMemoryProviderCreate(
103-
umfCUDAMemoryProviderOps(), CUMemoryProviderParams, MemoryProviderHost);
104-
UMF_RETURN_UR_ERROR(UmfResult);
105-
106-
UmfResult = umfPoolCreate(umfProxyPoolOps(), *MemoryProviderHost, nullptr, 0,
107-
MemoryPoolHost);
108-
UMF_RETURN_UR_ERROR(UmfResult);
109-
110-
return UR_RESULT_SUCCESS;
111-
}
112-
11381
struct ur_context_handle_t_ {
11482

11583
struct deleter_data {
@@ -132,8 +100,9 @@ struct ur_context_handle_t_ {
132100
// Create UMF CUDA memory provider for the host memory
133101
// (UMF_MEMORY_TYPE_HOST) from any device (Devices[0] is used here, because
134102
// it is guaranteed to exist).
135-
UR_CHECK_ERROR(CreateHostMemoryProviderPool(Devices[0], &MemoryProviderHost,
136-
&MemoryPoolHost));
103+
UR_CHECK_ERROR(umf::CreateProviderPool(
104+
0, Devices[0]->getNativeContext(), UMF_MEMORY_TYPE_HOST,
105+
&MemoryProviderHost, &MemoryPoolHost));
137106
UR_CHECK_ERROR(urAdapterRetain(ur::cuda::adapter));
138107
};
139108

unified-runtime/source/adapters/cuda/platform.cpp

Lines changed: 9 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -19,62 +19,6 @@
1919
#include <cuda.h>
2020
#include <sstream>
2121

22-
static ur_result_t
23-
CreateDeviceMemoryProvidersPools(ur_platform_handle_t_ *Platform) {
24-
umf_cuda_memory_provider_params_handle_t CUMemoryProviderParams = nullptr;
25-
26-
umf_result_t UmfResult =
27-
umfCUDAMemoryProviderParamsCreate(&CUMemoryProviderParams);
28-
UMF_RETURN_UR_ERROR(UmfResult);
29-
30-
OnScopeExit Cleanup(
31-
[=]() { umfCUDAMemoryProviderParamsDestroy(CUMemoryProviderParams); });
32-
33-
for (auto &Device : Platform->Devices) {
34-
ur_device_handle_t_ *device_handle = Device.get();
35-
CUdevice device = device_handle->get();
36-
CUcontext context = device_handle->getNativeContext();
37-
38-
// create UMF CUDA memory provider for the device memory
39-
// (UMF_MEMORY_TYPE_DEVICE)
40-
UmfResult = umf::setCUMemoryProviderParams(CUMemoryProviderParams, device,
41-
context, UMF_MEMORY_TYPE_DEVICE);
42-
UMF_RETURN_UR_ERROR(UmfResult);
43-
44-
UmfResult = umfMemoryProviderCreate(umfCUDAMemoryProviderOps(),
45-
CUMemoryProviderParams,
46-
&device_handle->MemoryProviderDevice);
47-
UMF_RETURN_UR_ERROR(UmfResult);
48-
49-
// create UMF CUDA memory provider for the shared memory
50-
// (UMF_MEMORY_TYPE_SHARED)
51-
UmfResult = umf::setCUMemoryProviderParams(CUMemoryProviderParams, device,
52-
context, UMF_MEMORY_TYPE_SHARED);
53-
UMF_RETURN_UR_ERROR(UmfResult);
54-
55-
UmfResult = umfMemoryProviderCreate(umfCUDAMemoryProviderOps(),
56-
CUMemoryProviderParams,
57-
&device_handle->MemoryProviderShared);
58-
UMF_RETURN_UR_ERROR(UmfResult);
59-
60-
// create UMF CUDA memory pool for the device memory
61-
// (UMF_MEMORY_TYPE_DEVICE)
62-
UmfResult =
63-
umfPoolCreate(umfProxyPoolOps(), device_handle->MemoryProviderDevice,
64-
nullptr, 0, &device_handle->MemoryPoolDevice);
65-
UMF_RETURN_UR_ERROR(UmfResult);
66-
67-
// create UMF CUDA memory pool for the shared memory
68-
// (UMF_MEMORY_TYPE_SHARED)
69-
UmfResult =
70-
umfPoolCreate(umfProxyPoolOps(), device_handle->MemoryProviderShared,
71-
nullptr, 0, &device_handle->MemoryPoolShared);
72-
UMF_RETURN_UR_ERROR(UmfResult);
73-
}
74-
75-
return UR_RESULT_SUCCESS;
76-
}
77-
7822
UR_APIEXPORT ur_result_t UR_APICALL
7923
urPlatformGetInfo(ur_platform_handle_t, ur_platform_info_t PlatformInfoType,
8024
size_t Size, void *pPlatformInfo, size_t *pSizeRet) {
@@ -148,10 +92,16 @@ urPlatformGet(ur_adapter_handle_t, uint32_t, ur_platform_handle_t *phPlatforms,
14892
new ur_device_handle_t_{Device, Context, EvBase,
14993
ur::cuda::adapter->Platform.get(),
15094
static_cast<uint32_t>(i)});
151-
}
15295

153-
UR_CHECK_ERROR(CreateDeviceMemoryProvidersPools(
154-
ur::cuda::adapter->Platform.get()));
96+
// Create UMF memory providers and pools
97+
auto &dev = ur::cuda::adapter->Platform->Devices.back();
98+
UR_CHECK_ERROR(umf::CreateProviderPool(
99+
Device, Context, UMF_MEMORY_TYPE_DEVICE,
100+
&dev->MemoryProviderDevice, &dev->MemoryPoolDevice));
101+
UR_CHECK_ERROR(umf::CreateProviderPool(
102+
Device, Context, UMF_MEMORY_TYPE_SHARED,
103+
&dev->MemoryProviderShared, &dev->MemoryPoolShared));
104+
}
155105
} catch (const std::bad_alloc &) {
156106
// Signal out-of-memory situation
157107
for (int i = 0; i < NumDevices; ++i) {

0 commit comments

Comments
 (0)