Skip to content

Commit 9d300d4

Browse files
committed
add support for CUDA allocation flags
1 parent 1fa3f8a commit 9d300d4

File tree

8 files changed

+185
-18
lines changed

8 files changed

+185
-18
lines changed

include/umf/providers/provider_cuda.h

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (C) 2024 Intel Corporation
2+
* Copyright (C) 2024-2025 Intel Corporation
33
*
44
* Under the Apache License v2.0 with LLVM Exceptions. See LICENSE.TXT.
55
* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
@@ -53,6 +53,22 @@ umf_result_t umfCUDAMemoryProviderParamsSetMemoryType(
5353
umf_cuda_memory_provider_params_handle_t hParams,
5454
umf_usm_memory_type_t memoryType);
5555

56+
/// @brief Set the host allocation flags in the parameters struct.
57+
/// @param hParams handle to the parameters of the CUDA Memory Provider.
58+
/// @param flags combination of CU_MEMHOSTALLOC_PORTABLE,
59+
/// CU_MEMHOSTALLOC_DEVICEMAP and CU_MEMHOSTALLOC_WRITECOMBINED flags
60+
/// @return UMF_RESULT_SUCCESS on success or appropriate error code on failure.
61+
umf_result_t umfCUDAMemoryProviderParamsSetHostAllocFlags(
62+
umf_cuda_memory_provider_params_handle_t hParams, unsigned int flags);
63+
64+
/// @brief Set the managed allocation flags in the parameters struct.
65+
/// @param hParams handle to the parameters of the CUDA Memory Provider.
66+
/// @param flags must be one of CU_MEM_ATTACH_GLOBAL or CU_MEM_ATTACH_HOST
67+
/// flags
68+
/// @return UMF_RESULT_SUCCESS on success or appropriate error code on failure.
69+
umf_result_t umfCUDAMemoryProviderParamsSetManagedAllocFlags(
70+
umf_cuda_memory_provider_params_handle_t hParams, unsigned int flags);
71+
5672
umf_memory_provider_ops_t *umfCUDAMemoryProviderOps(void);
5773

5874
#ifdef __cplusplus

src/libumf.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,8 @@ EXPORTS
118118
umfScalablePoolParamsSetGranularity
119119
umfScalablePoolParamsSetKeepAllMemory
120120
; Added in UMF_0.11
121+
umfCUDAMemoryProviderParamsSetHostAllocFlags
122+
umfCUDAMemoryProviderParamsSetManagedAllocFlags
121123
umfFixedMemoryProviderOps
122124
umfFixedMemoryProviderParamsCreate
123125
umfFixedMemoryProviderParamsDestroy

src/libumf.map

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,8 @@ UMF_0.10 {
116116
};
117117

118118
UMF_0.11 {
119+
umfCUDAMemoryProviderParamsSetHostAllocFlags;
120+
umfCUDAMemoryProviderParamsSetManagedAllocFlags;
119121
umfFixedMemoryProviderOps;
120122
umfFixedMemoryProviderParamsCreate;
121123
umfFixedMemoryProviderParamsDestroy;

src/provider/provider_cuda.c

Lines changed: 76 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,22 @@ umf_result_t umfCUDAMemoryProviderParamsSetMemoryType(
5555
return UMF_RESULT_ERROR_NOT_SUPPORTED;
5656
}
5757

58+
umf_result_t umfCUDAMemoryProviderParamsSetHostAllocFlags(
59+
umf_cuda_memory_provider_params_handle_t hParams, unsigned int flags) {
60+
(void)hParams;
61+
(void)flags;
62+
LOG_ERR("CUDA provider is disabled (UMF_BUILD_CUDA_PROVIDER is OFF)!");
63+
return UMF_RESULT_ERROR_NOT_SUPPORTED;
64+
}
65+
66+
umf_result_t umfCUDAMemoryProviderParamsSetManagedAllocFlags(
67+
umf_cuda_memory_provider_params_handle_t hParams, unsigned int flags) {
68+
(void)hParams;
69+
(void)flags;
70+
LOG_ERR("CUDA provider is disabled (UMF_BUILD_CUDA_PROVIDER is OFF)!");
71+
return UMF_RESULT_ERROR_NOT_SUPPORTED;
72+
}
73+
5874
umf_memory_provider_ops_t *umfCUDAMemoryProviderOps(void) {
5975
// not supported
6076
LOG_ERR("CUDA provider is disabled (UMF_BUILD_CUDA_PROVIDER is OFF)!");
@@ -89,13 +105,17 @@ typedef struct cu_memory_provider_t {
89105
CUdevice device;
90106
umf_usm_memory_type_t memory_type;
91107
size_t min_alignment;
108+
unsigned int host_alloc_flags;
109+
unsigned int managed_alloc_flags;
92110
} cu_memory_provider_t;
93111

94112
// CUDA Memory Provider settings struct
95113
typedef struct umf_cuda_memory_provider_params_t {
96-
void *cuda_context_handle; ///< Handle to the CUDA context
97-
int cuda_device_handle; ///< Handle to the CUDA device
98-
umf_usm_memory_type_t memory_type; ///< Allocation memory type
114+
void *cuda_context_handle; // Handle to the CUDA context
115+
int cuda_device_handle; // Handle to the CUDA device
116+
umf_usm_memory_type_t memory_type; // Allocation memory type
117+
unsigned int host_alloc_flags; // Allocation flags for cuMemHostAlloc
118+
unsigned int managed_alloc_flags; // Allocation flags for cuMemAllocManaged
99119
} umf_cuda_memory_provider_params_t;
100120

101121
typedef struct cu_ops_t {
@@ -104,6 +124,7 @@ typedef struct cu_ops_t {
104124
CUmemAllocationGranularity_flags option);
105125
CUresult (*cuMemAlloc)(CUdeviceptr *dptr, size_t bytesize);
106126
CUresult (*cuMemAllocHost)(void **pp, size_t bytesize);
127+
CUresult (*cuMemHostAlloc)(void **pp, size_t bytesize, unsigned int flags);
107128
CUresult (*cuMemAllocManaged)(CUdeviceptr *dptr, size_t bytesize,
108129
unsigned int flags);
109130
CUresult (*cuMemFree)(CUdeviceptr dptr);
@@ -175,6 +196,8 @@ static void init_cu_global_state(void) {
175196
utils_get_symbol_addr(0, "cuMemAlloc_v2", lib_name);
176197
*(void **)&g_cu_ops.cuMemAllocHost =
177198
utils_get_symbol_addr(0, "cuMemAllocHost_v2", lib_name);
199+
*(void **)&g_cu_ops.cuMemHostAlloc =
200+
utils_get_symbol_addr(0, "cuMemHostAlloc", lib_name);
178201
*(void **)&g_cu_ops.cuMemAllocManaged =
179202
utils_get_symbol_addr(0, "cuMemAllocManaged", lib_name);
180203
*(void **)&g_cu_ops.cuMemFree =
@@ -197,12 +220,12 @@ static void init_cu_global_state(void) {
197220
utils_get_symbol_addr(0, "cuIpcCloseMemHandle", lib_name);
198221

199222
if (!g_cu_ops.cuMemGetAllocationGranularity || !g_cu_ops.cuMemAlloc ||
200-
!g_cu_ops.cuMemAllocHost || !g_cu_ops.cuMemAllocManaged ||
201-
!g_cu_ops.cuMemFree || !g_cu_ops.cuMemFreeHost ||
202-
!g_cu_ops.cuGetErrorName || !g_cu_ops.cuGetErrorString ||
203-
!g_cu_ops.cuCtxGetCurrent || !g_cu_ops.cuCtxSetCurrent ||
204-
!g_cu_ops.cuIpcGetMemHandle || !g_cu_ops.cuIpcOpenMemHandle ||
205-
!g_cu_ops.cuIpcCloseMemHandle) {
223+
!g_cu_ops.cuMemAllocHost || !g_cu_ops.cuMemHostAlloc ||
224+
!g_cu_ops.cuMemAllocManaged || !g_cu_ops.cuMemFree ||
225+
!g_cu_ops.cuMemFreeHost || !g_cu_ops.cuGetErrorName ||
226+
!g_cu_ops.cuGetErrorString || !g_cu_ops.cuCtxGetCurrent ||
227+
!g_cu_ops.cuCtxSetCurrent || !g_cu_ops.cuIpcGetMemHandle ||
228+
!g_cu_ops.cuIpcOpenMemHandle || !g_cu_ops.cuIpcCloseMemHandle) {
206229
LOG_ERR("Required CUDA symbols not found.");
207230
Init_cu_global_state_failed = true;
208231
}
@@ -226,6 +249,8 @@ umf_result_t umfCUDAMemoryProviderParamsCreate(
226249
params_data->cuda_context_handle = NULL;
227250
params_data->cuda_device_handle = -1;
228251
params_data->memory_type = UMF_MEMORY_TYPE_UNKNOWN;
252+
params_data->host_alloc_flags = 0;
253+
params_data->managed_alloc_flags = CU_MEM_ATTACH_GLOBAL;
229254

230255
*hParams = params_data;
231256

@@ -276,6 +301,42 @@ umf_result_t umfCUDAMemoryProviderParamsSetMemoryType(
276301
return UMF_RESULT_SUCCESS;
277302
}
278303

304+
umf_result_t umfCUDAMemoryProviderParamsSetHostAllocFlags(
305+
umf_cuda_memory_provider_params_handle_t hParams, unsigned int flags) {
306+
if (!hParams) {
307+
LOG_ERR("CUDA Memory Provider params handle is NULL");
308+
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
309+
}
310+
311+
// mask out valid flags and check if there are bits left
312+
if (flags & ~(CU_MEMHOSTALLOC_PORTABLE | CU_MEMHOSTALLOC_DEVICEMAP |
313+
CU_MEMHOSTALLOC_WRITECOMBINED)) {
314+
LOG_ERR("Invalid host allocation flags");
315+
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
316+
}
317+
318+
hParams->host_alloc_flags = flags;
319+
320+
return UMF_RESULT_SUCCESS;
321+
}
322+
323+
umf_result_t umfCUDAMemoryProviderParamsSetManagedAllocFlags(
324+
umf_cuda_memory_provider_params_handle_t hParams, unsigned int flags) {
325+
if (!hParams) {
326+
LOG_ERR("CUDA Memory Provider params handle is NULL");
327+
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
328+
}
329+
330+
if (flags != CU_MEM_ATTACH_GLOBAL && flags != CU_MEM_ATTACH_HOST) {
331+
LOG_ERR("Invalid managed allocation flags");
332+
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
333+
}
334+
335+
hParams->managed_alloc_flags = flags;
336+
337+
return UMF_RESULT_SUCCESS;
338+
}
339+
279340
static umf_result_t cu_memory_provider_initialize(void *params,
280341
void **provider) {
281342
if (params == NULL) {
@@ -325,6 +386,8 @@ static umf_result_t cu_memory_provider_initialize(void *params,
325386
cu_provider->device = cu_params->cuda_device_handle;
326387
cu_provider->memory_type = cu_params->memory_type;
327388
cu_provider->min_alignment = min_alignment;
389+
cu_provider->host_alloc_flags = cu_params->host_alloc_flags;
390+
cu_provider->managed_alloc_flags = cu_params->managed_alloc_flags;
328391

329392
*provider = cu_provider;
330393

@@ -382,16 +445,17 @@ static umf_result_t cu_memory_provider_alloc(void *provider, size_t size,
382445
CUresult cu_result = CUDA_SUCCESS;
383446
switch (cu_provider->memory_type) {
384447
case UMF_MEMORY_TYPE_HOST: {
385-
cu_result = g_cu_ops.cuMemAllocHost(resultPtr, size);
448+
cu_result = g_cu_ops.cuMemHostAlloc(resultPtr, size,
449+
cu_provider->host_alloc_flags);
386450
break;
387451
}
388452
case UMF_MEMORY_TYPE_DEVICE: {
389453
cu_result = g_cu_ops.cuMemAlloc((CUdeviceptr *)resultPtr, size);
390454
break;
391455
}
392456
case UMF_MEMORY_TYPE_SHARED: {
393-
cu_result = g_cu_ops.cuMemAllocManaged((CUdeviceptr *)resultPtr, size,
394-
CU_MEM_ATTACH_GLOBAL);
457+
cu_result = g_cu_ops.cuMemAllocManaged(
458+
(CUdeviceptr *)resultPtr, size, cu_provider->managed_alloc_flags);
395459
break;
396460
}
397461
default:

test/providers/cuda_helpers.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ struct libcu_ops {
2323
CUresult (*cuMemAlloc)(CUdeviceptr *dptr, size_t size);
2424
CUresult (*cuMemFree)(CUdeviceptr dptr);
2525
CUresult (*cuMemAllocHost)(void **pp, size_t size);
26+
CUresult (*cuMemHostAlloc)(void **pp, size_t size, unsigned int flags);
2627
CUresult (*cuMemAllocManaged)(CUdeviceptr *dptr, size_t bytesize,
2728
unsigned int flags);
2829
CUresult (*cuMemFreeHost)(void *p);
@@ -34,6 +35,7 @@ struct libcu_ops {
3435
CUresult (*cuPointerGetAttributes)(unsigned int numAttributes,
3536
CUpointer_attribute *attributes,
3637
void **data, CUdeviceptr ptr);
38+
CUresult (*cuMemHostGetFlags)(unsigned int *pFlags, void *p);
3739
CUresult (*cuStreamSynchronize)(CUstream hStream);
3840
CUresult (*cuCtxSynchronize)(void);
3941
} libcu_ops;
@@ -72,6 +74,9 @@ struct DlHandleCloser {
7274
libcu_ops.cuMemAllocHost = [](auto... args) {
7375
return noop_stub(args...);
7476
};
77+
libcu_ops.cuMemHostAlloc = [](auto... args) {
78+
return noop_stub(args...);
79+
};
7580
libcu_ops.cuMemAllocManaged = [](auto... args) {
7681
return noop_stub(args...);
7782
};
@@ -90,6 +95,9 @@ struct DlHandleCloser {
9095
libcu_ops.cuPointerGetAttributes = [](auto... args) {
9196
return noop_stub(args...);
9297
};
98+
libcu_ops.cuMemHostGetFlags = [](auto... args) {
99+
return noop_stub(args...);
100+
};
93101
libcu_ops.cuStreamSynchronize = [](auto... args) {
94102
return noop_stub(args...);
95103
};
@@ -170,6 +178,12 @@ int InitCUDAOps() {
170178
fprintf(stderr, "cuMemAllocHost_v2 symbol not found in %s\n", lib_name);
171179
return -1;
172180
}
181+
*(void **)&libcu_ops.cuMemHostAlloc =
182+
utils_get_symbol_addr(cuDlHandle.get(), "cuMemHostAlloc", lib_name);
183+
if (libcu_ops.cuMemHostAlloc == nullptr) {
184+
fprintf(stderr, "cuMemHostAlloc symbol not found in %s\n", lib_name);
185+
return -1;
186+
}
173187
*(void **)&libcu_ops.cuMemAllocManaged =
174188
utils_get_symbol_addr(cuDlHandle.get(), "cuMemAllocManaged", lib_name);
175189
if (libcu_ops.cuMemAllocManaged == nullptr) {
@@ -208,6 +222,12 @@ int InitCUDAOps() {
208222
lib_name);
209223
return -1;
210224
}
225+
*(void **)&libcu_ops.cuMemHostGetFlags =
226+
utils_get_symbol_addr(cuDlHandle.get(), "cuMemHostGetFlags", lib_name);
227+
if (libcu_ops.cuMemHostGetFlags == nullptr) {
228+
fprintf(stderr, "cuMemHostGetFlags symbol not found in %s\n", lib_name);
229+
return -1;
230+
}
211231
*(void **)&libcu_ops.cuStreamSynchronize = utils_get_symbol_addr(
212232
cuDlHandle.get(), "cuStreamSynchronize", lib_name);
213233
if (libcu_ops.cuStreamSynchronize == nullptr) {
@@ -237,13 +257,15 @@ int InitCUDAOps() {
237257
libcu_ops.cuDeviceGet = cuDeviceGet;
238258
libcu_ops.cuMemAlloc = cuMemAlloc;
239259
libcu_ops.cuMemAllocHost = cuMemAllocHost;
260+
libcu_ops.cuMemHostAlloc = cuMemHostAlloc;
240261
libcu_ops.cuMemAllocManaged = cuMemAllocManaged;
241262
libcu_ops.cuMemFree = cuMemFree;
242263
libcu_ops.cuMemFreeHost = cuMemFreeHost;
243264
libcu_ops.cuMemsetD32 = cuMemsetD32;
244265
libcu_ops.cuMemcpy = cuMemcpy;
245266
libcu_ops.cuPointerGetAttribute = cuPointerGetAttribute;
246267
libcu_ops.cuPointerGetAttributes = cuPointerGetAttributes;
268+
libcu_ops.cuMemHostGetFlags = cuMemHostGetFlags;
247269
libcu_ops.cuStreamSynchronize = cuStreamSynchronize;
248270
libcu_ops.cuCtxSynchronize = cuCtxSynchronize;
249271

@@ -373,6 +395,17 @@ umf_usm_memory_type_t get_mem_type(CUcontext context, void *ptr) {
373395
return UMF_MEMORY_TYPE_UNKNOWN;
374396
}
375397

398+
unsigned int get_mem_host_alloc_flags(void *ptr) {
399+
unsigned int flags;
400+
CUresult res = libcu_ops.cuMemHostGetFlags(&flags, ptr);
401+
if (res != CUDA_SUCCESS) {
402+
fprintf(stderr, "cuPointerGetAttribute() failed!\n");
403+
return 0;
404+
}
405+
406+
return flags;
407+
}
408+
376409
CUcontext get_mem_context(void *ptr) {
377410
CUcontext context;
378411
CUresult res = libcu_ops.cuPointerGetAttribute(

test/providers/cuda_helpers.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ int cuda_copy(CUcontext context, CUdevice device, void *dst_ptr,
4242

4343
umf_usm_memory_type_t get_mem_type(CUcontext context, void *ptr);
4444

45+
unsigned int get_mem_host_alloc_flags(void *ptr);
46+
4547
CUcontext get_mem_context(void *ptr);
4648

4749
CUcontext get_current_context();

0 commit comments

Comments
 (0)