Skip to content

Commit 4b554ef

Browse files
committed
Update Cuda provider config API
1 parent dffe4e6 commit 4b554ef

File tree

12 files changed

+481
-81
lines changed

12 files changed

+481
-81
lines changed

examples/cuda_shared_memory/cuda_shared_memory.c

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -45,22 +45,49 @@ int main(void) {
4545

4646
// Setup parameters for the CUDA memory provider. It will be used for
4747
// allocating memory from CUDA devices.
48-
cuda_memory_provider_params_t cu_memory_provider_params;
49-
cu_memory_provider_params.cuda_context_handle = cuContext;
50-
cu_memory_provider_params.cuda_device_handle = cuDevice;
48+
umf_cuda_memory_provider_params_handle_t cu_memory_provider_params = NULL;
49+
res = umfCudaMemoryProviderParamsCreate(&cu_memory_provider_params);
50+
if (res != UMF_RESULT_SUCCESS) {
51+
fprintf(stderr, "Failed to create memory provider params!\n");
52+
ret = -1;
53+
goto cuda_destroy;
54+
}
55+
56+
res = umfCudaMemoryProviderParamsSetContext(cu_memory_provider_params,
57+
cuContext);
58+
if (res != UMF_RESULT_SUCCESS) {
59+
fprintf(stderr, "Failed to set context in memory provider params!\n");
60+
ret = -1;
61+
goto provider_params_destroy;
62+
}
63+
64+
res = umfCudaMemoryProviderParamsSetDevice(cu_memory_provider_params,
65+
cuDevice);
66+
if (res != UMF_RESULT_SUCCESS) {
67+
fprintf(stderr, "Failed to set device in memory provider params!\n");
68+
ret = -1;
69+
goto provider_params_destroy;
70+
}
5171
// Set the memory type to shared to allow the memory to be accessed on both
5272
// CPU and GPU.
53-
cu_memory_provider_params.memory_type = UMF_MEMORY_TYPE_SHARED;
73+
res = umfCudaMemoryProviderParamsSetMemoryType(cu_memory_provider_params,
74+
UMF_MEMORY_TYPE_SHARED);
75+
if (res != UMF_RESULT_SUCCESS) {
76+
fprintf(stderr,
77+
"Failed to set memory type in memory provider params!\n");
78+
ret = -1;
79+
goto provider_params_destroy;
80+
}
5481

5582
// Create CUDA memory provider
5683
umf_memory_provider_handle_t cu_memory_provider;
57-
res = umfMemoryProviderCreate(umfCUDAMemoryProviderOps(),
58-
&cu_memory_provider_params,
59-
&cu_memory_provider);
84+
res =
85+
umfMemoryProviderCreate(umfCUDAMemoryProviderOps(),
86+
cu_memory_provider_params, &cu_memory_provider);
6087
if (res != UMF_RESULT_SUCCESS) {
6188
fprintf(stderr, "Failed to create a memory provider!\n");
6289
ret = -1;
63-
goto cuda_destroy;
90+
goto provider_params_destroy;
6491
}
6592

6693
printf("CUDA memory provider created at %p\n", (void *)cu_memory_provider);
@@ -147,6 +174,9 @@ int main(void) {
147174
memory_provider_destroy:
148175
umfMemoryProviderDestroy(cu_memory_provider);
149176

177+
provider_params_destroy:
178+
umfCudaMemoryProviderParamsDestroy(cu_memory_provider_params);
179+
150180
cuda_destroy:
151181
ret = cuCtxDestroy(cuContext);
152182
return ret;

include/umf/providers/provider_cuda.h

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,44 @@
1414
extern "C" {
1515
#endif
1616

17-
/// @brief CUDA Memory Provider settings struct
18-
typedef struct cuda_memory_provider_params_t {
19-
void *cuda_context_handle; ///< Handle to the CUDA context
20-
int cuda_device_handle; ///< Handle to the CUDA device
21-
umf_usm_memory_type_t memory_type; ///< Allocation memory type
22-
} cuda_memory_provider_params_t;
17+
struct umf_cuda_memory_provider_params_t;
18+
19+
typedef struct umf_cuda_memory_provider_params_t
20+
*umf_cuda_memory_provider_params_handle_t;
21+
22+
/// @brief Create a struct to store parameters of the Cuda memory provider.
23+
/// @param hParams [out] handle to the newly created parameters struct.
24+
/// @return UMF_RESULT_SUCCESS on success or appropriate error code on failure.
25+
umf_result_t umfCudaMemoryProviderParamsCreate(
26+
umf_cuda_memory_provider_params_handle_t *hParams);
27+
28+
/// @brief Destroy parameters struct.
29+
/// @param hParams handle to the parameters of the Cuda memory provider.
30+
/// @return UMF_RESULT_SUCCESS on success or appropriate error code on failure.
31+
umf_result_t umfCudaMemoryProviderParamsDestroy(
32+
umf_cuda_memory_provider_params_handle_t hParams);
33+
34+
/// @brief Set the Cuda context handle in the parameters struct.
35+
/// @param hParams handle to the parameters of the Cuda memory provider.
36+
/// @param hContext handle to the Cuda context.
37+
/// @return UMF_RESULT_SUCCESS on success or appropriate error code on failure.
38+
umf_result_t umfCudaMemoryProviderParamsSetContext(
39+
umf_cuda_memory_provider_params_handle_t hParams, void *hContext);
40+
41+
/// @brief Set the Cuda device handle in the parameters struct.
42+
/// @param hParams handle to the parameters of the Cuda memory provider.
43+
/// @param hDevice handle to the Cuda device.
44+
/// @return UMF_RESULT_SUCCESS on success or appropriate error code on failure.
45+
umf_result_t umfCudaMemoryProviderParamsSetDevice(
46+
umf_cuda_memory_provider_params_handle_t hParams, int hDevice);
47+
48+
/// @brief Set the memory type in the parameters struct.
49+
/// @param hParams handle to the parameters of the Cuda memory provider.
50+
/// @param memoryType memory type.
51+
/// @return UMF_RESULT_SUCCESS on success or appropriate error code on failure.
52+
umf_result_t umfCudaMemoryProviderParamsSetMemoryType(
53+
umf_cuda_memory_provider_params_handle_t hParams,
54+
umf_usm_memory_type_t memoryType);
2355

2456
umf_memory_provider_ops_t *umfCUDAMemoryProviderOps(void);
2557

src/libumf.def

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,11 @@ EXPORTS
1717
umfCoarseMemoryProviderGetStats
1818
umfCoarseMemoryProviderOps
1919
umfCUDAMemoryProviderOps
20+
umfCudaMemoryProviderParamsCreate
21+
umfCudaMemoryProviderParamsDestroy
22+
umfCudaMemoryProviderParamsSetContext
23+
umfCudaMemoryProviderParamsSetDevice
24+
umfCudaMemoryProviderParamsSetMemoryType
2025
umfDevDaxMemoryProviderOps
2126
umfFree
2227
umfFileMemoryProviderOps

src/libumf.map

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,11 @@ UMF_1.0 {
1111
umfCoarseMemoryProviderGetStats;
1212
umfCoarseMemoryProviderOps;
1313
umfCUDAMemoryProviderOps;
14+
umfCudaMemoryProviderParamsCreate;
15+
umfCudaMemoryProviderParamsDestroy;
16+
umfCudaMemoryProviderParamsSetContext;
17+
umfCudaMemoryProviderParamsSetDevice;
18+
umfCudaMemoryProviderParamsSetMemoryType;
1419
umfDevDaxMemoryProviderOps;
1520
umfFree;
1621
umfFileMemoryProviderOps;

src/provider/provider_cuda.c

Lines changed: 110 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,40 @@
1414

1515
#if defined(UMF_NO_CUDA_PROVIDER)
1616

17+
umf_result_t umfCudaMemoryProviderParamsCreate(
18+
umf_cuda_memory_provider_params_handle_t *hParams) {
19+
(void)hParams;
20+
return UMF_RESULT_ERROR_NOT_SUPPORTED;
21+
}
22+
23+
umf_result_t umfCudaMemoryProviderParamsDestroy(
24+
umf_cuda_memory_provider_params_handle_t hParams) {
25+
(void)hParams;
26+
return UMF_RESULT_ERROR_NOT_SUPPORTED;
27+
}
28+
29+
umf_result_t umfCudaMemoryProviderParamsSetContext(
30+
umf_cuda_memory_provider_params_handle_t hParams, void *hContext) {
31+
(void)hParams;
32+
(void)hContext;
33+
return UMF_RESULT_ERROR_NOT_SUPPORTED;
34+
}
35+
36+
umf_result_t umfCudaMemoryProviderParamsSetDevice(
37+
umf_cuda_memory_provider_params_handle_t hParams, int hDevice) {
38+
(void)hParams;
39+
(void)hDevice;
40+
return UMF_RESULT_ERROR_NOT_SUPPORTED;
41+
}
42+
43+
umf_result_t umfCudaMemoryProviderParamsSetMemoryType(
44+
umf_cuda_memory_provider_params_handle_t hParams,
45+
umf_usm_memory_type_t memoryType) {
46+
(void)hParams;
47+
(void)memoryType;
48+
return UMF_RESULT_ERROR_NOT_SUPPORTED;
49+
}
50+
1751
umf_memory_provider_ops_t *umfCUDAMemoryProviderOps(void) {
1852
// not supported
1953
return NULL;
@@ -48,6 +82,13 @@ typedef struct cu_memory_provider_t {
4882
size_t min_alignment;
4983
} cu_memory_provider_t;
5084

85+
// CUDA Memory Provider settings struct
86+
typedef struct umf_cuda_memory_provider_params_t {
87+
void *cuda_context_handle; ///< Handle to the CUDA context
88+
int cuda_device_handle; ///< Handle to the CUDA device
89+
umf_usm_memory_type_t memory_type; ///< Allocation memory type
90+
} umf_cuda_memory_provider_params_t;
91+
5192
typedef struct cu_ops_t {
5293
CUresult (*cuMemGetAllocationGranularity)(
5394
size_t *granularity, const CUmemAllocationProp *prop,
@@ -158,14 +199,81 @@ static void init_cu_global_state(void) {
158199
}
159200
}
160201

202+
umf_result_t umfCudaMemoryProviderParamsCreate(
203+
umf_cuda_memory_provider_params_handle_t *hParams) {
204+
if (!hParams) {
205+
LOG_ERR("CUDA memory provider params handle is NULL");
206+
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
207+
}
208+
209+
umf_cuda_memory_provider_params_handle_t params_data =
210+
umf_ba_global_alloc(sizeof(umf_cuda_memory_provider_params_t));
211+
if (!params_data) {
212+
LOG_ERR("Cannot allocate memory for CUDA memory provider params");
213+
return UMF_RESULT_ERROR_OUT_OF_HOST_MEMORY;
214+
}
215+
216+
params_data->cuda_context_handle = NULL;
217+
params_data->cuda_device_handle = -1;
218+
params_data->memory_type = UMF_MEMORY_TYPE_UNKNOWN;
219+
220+
*hParams = params_data;
221+
222+
return UMF_RESULT_SUCCESS;
223+
}
224+
225+
umf_result_t umfCudaMemoryProviderParamsDestroy(
226+
umf_cuda_memory_provider_params_handle_t hParams) {
227+
umf_ba_global_free(hParams);
228+
229+
return UMF_RESULT_SUCCESS;
230+
}
231+
232+
umf_result_t umfCudaMemoryProviderParamsSetContext(
233+
umf_cuda_memory_provider_params_handle_t hParams, void *hContext) {
234+
if (!hParams) {
235+
LOG_ERR("CUDA memory provider params handle is NULL");
236+
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
237+
}
238+
239+
hParams->cuda_context_handle = hContext;
240+
241+
return UMF_RESULT_SUCCESS;
242+
}
243+
244+
umf_result_t umfCudaMemoryProviderParamsSetDevice(
245+
umf_cuda_memory_provider_params_handle_t hParams, int hDevice) {
246+
if (!hParams) {
247+
LOG_ERR("CUDA memory provider params handle is NULL");
248+
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
249+
}
250+
251+
hParams->cuda_device_handle = hDevice;
252+
253+
return UMF_RESULT_SUCCESS;
254+
}
255+
256+
umf_result_t umfCudaMemoryProviderParamsSetMemoryType(
257+
umf_cuda_memory_provider_params_handle_t hParams,
258+
umf_usm_memory_type_t memoryType) {
259+
if (!hParams) {
260+
LOG_ERR("CUDA memory provider params handle is NULL");
261+
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
262+
}
263+
264+
hParams->memory_type = memoryType;
265+
266+
return UMF_RESULT_SUCCESS;
267+
}
268+
161269
static umf_result_t cu_memory_provider_initialize(void *params,
162270
void **provider) {
163271
if (params == NULL) {
164272
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
165273
}
166274

167-
cuda_memory_provider_params_t *cu_params =
168-
(cuda_memory_provider_params_t *)params;
275+
umf_cuda_memory_provider_params_handle_t cu_params =
276+
(umf_cuda_memory_provider_params_handle_t)params;
169277

170278
if (cu_params->memory_type == UMF_MEMORY_TYPE_UNKNOWN ||
171279
cu_params->memory_type > UMF_MEMORY_TYPE_SHARED) {

test/providers/cuda_helpers.cpp

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -355,38 +355,40 @@ int init_cuda() {
355355
return InitResult;
356356
}
357357

358-
cuda_memory_provider_params_t
359-
create_cuda_prov_params(umf_usm_memory_type_t memory_type) {
360-
cuda_memory_provider_params_t params = {NULL, 0, UMF_MEMORY_TYPE_UNKNOWN};
361-
int ret = -1;
358+
int get_cuda_device(CUdevice *device) {
359+
CUdevice cuDevice = -1;
362360

363-
ret = init_cuda();
361+
int ret = init_cuda();
364362
if (ret != 0) {
365-
// Return empty params. Test will be skipped.
366-
return params;
363+
fprintf(stderr, "init_cuda() failed!\n");
364+
return ret;
367365
}
368366

369-
// Get the first CUDA device
370-
CUdevice cuDevice = -1;
371367
CUresult res = libcu_ops.cuDeviceGet(&cuDevice, 0);
372368
if (res != CUDA_SUCCESS || cuDevice < 0) {
373-
// Return empty params. Test will be skipped.
374-
return params;
369+
return -1;
375370
}
376371

377-
// Create a CUDA context
372+
*device = cuDevice;
373+
return 0;
374+
}
375+
376+
int create_context(CUdevice device, CUcontext *context) {
378377
CUcontext cuContext = nullptr;
379-
res = libcu_ops.cuCtxCreate(&cuContext, 0, cuDevice);
380-
if (res != CUDA_SUCCESS || cuContext == nullptr) {
381-
// Return empty params. Test will be skipped.
382-
return params;
378+
379+
int ret = init_cuda();
380+
if (ret != 0) {
381+
fprintf(stderr, "init_cuda() failed!\n");
382+
return ret;
383383
}
384384

385-
params.cuda_context_handle = cuContext;
386-
params.cuda_device_handle = cuDevice;
387-
params.memory_type = memory_type;
385+
CUresult res = libcu_ops.cuCtxCreate(&cuContext, 0, device);
386+
if (res != CUDA_SUCCESS || cuContext == nullptr) {
387+
return -1;
388+
}
388389

389-
return params;
390+
*context = cuContext;
391+
return 0;
390392
}
391393

392394
int destroy_context(CUcontext context) {

test/providers/cuda_helpers.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@
2626
extern "C" {
2727
#endif
2828

29+
int get_cuda_device(CUdevice *device);
30+
31+
int create_context(CUdevice device, CUcontext *context);
32+
2933
int destroy_context(CUcontext context);
3034

3135
int cuda_fill(CUcontext context, CUdevice device, void *ptr, size_t size,
@@ -40,9 +44,6 @@ CUcontext get_mem_context(void *ptr);
4044

4145
CUcontext get_current_context();
4246

43-
cuda_memory_provider_params_t
44-
create_cuda_prov_params(umf_usm_memory_type_t memory_type);
45-
4647
#ifdef __cplusplus
4748
}
4849
#endif

test/providers/ipc_cuda_prov_common.c

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,9 @@
1313
#include "ipc_cuda_prov_common.h"
1414

1515
void memcopy(void *dst, const void *src, size_t size, void *context) {
16-
cuda_memory_provider_params_t *cu_params =
17-
(cuda_memory_provider_params_t *)context;
18-
int ret = cuda_copy(cu_params->cuda_context_handle,
19-
cu_params->cuda_device_handle, dst, (void *)src, size);
16+
cuda_copy_ctx_t *cu_params = (cuda_copy_ctx_t *)context;
17+
int ret = cuda_copy(cu_params->context, cu_params->device, dst, (void *)src,
18+
size);
2019
if (ret != 0) {
2120
fprintf(stderr, "cuda_copy failed with error %d\n", ret);
2221
}

test/providers/ipc_cuda_prov_common.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,11 @@
1010

1111
#include <stddef.h>
1212

13+
typedef struct cuda_copy_ctx_t {
14+
CUcontext context;
15+
CUdevice device;
16+
} cuda_copy_ctx_t;
17+
1318
void memcopy(void *dst, const void *src, size_t size, void *context);
1419

1520
#endif // UMF_TEST_IPC_CUDA_PROV_COMMON_H

0 commit comments

Comments
 (0)