Skip to content

Commit 1160a27

Browse files
authored
Merge pull request #1079 from bratpiorka/rrudnick_cuda_flags
add support for CUDA allocation flags
2 parents 727a796 + d36d585 commit 1160a27

File tree

8 files changed

+224
-22
lines changed

8 files changed

+224
-22
lines changed

include/umf/providers/provider_cuda.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (C) 2024 Intel Corporation
2+
* Copyright (C) 2024-2025 Intel Corporation
33
*
44
* Under the Apache License v2.0 with LLVM Exceptions. See LICENSE.TXT.
55
* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
@@ -53,6 +53,13 @@ umf_result_t umfCUDAMemoryProviderParamsSetMemoryType(
5353
umf_cuda_memory_provider_params_handle_t hParams,
5454
umf_usm_memory_type_t memoryType);
5555

56+
/// @brief Set the allocation flags in the parameters struct.
57+
/// @param hParams handle to the parameters of the CUDA Memory Provider.
58+
/// @param flags valid combination of CUDA allocation flags.
59+
/// @return UMF_RESULT_SUCCESS on success or appropriate error code on failure.
60+
umf_result_t umfCUDAMemoryProviderParamsSetAllocFlags(
61+
umf_cuda_memory_provider_params_handle_t hParams, unsigned int flags);
62+
5663
umf_memory_provider_ops_t *umfCUDAMemoryProviderOps(void);
5764

5865
#ifdef __cplusplus

src/libumf.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ EXPORTS
118118
umfScalablePoolParamsSetGranularity
119119
umfScalablePoolParamsSetKeepAllMemory
120120
; Added in UMF_0.11
121+
umfCUDAMemoryProviderParamsSetAllocFlags
121122
umfFixedMemoryProviderOps
122123
umfFixedMemoryProviderParamsCreate
123124
umfFixedMemoryProviderParamsDestroy

src/libumf.map

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ UMF_0.10 {
116116
};
117117

118118
UMF_0.11 {
119+
umfCUDAMemoryProviderParamsSetAllocFlags;
119120
umfFixedMemoryProviderOps;
120121
umfFixedMemoryProviderParamsCreate;
121122
umfFixedMemoryProviderParamsDestroy;

src/provider/provider_cuda.c

Lines changed: 51 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,14 @@ umf_result_t umfCUDAMemoryProviderParamsSetMemoryType(
5555
return UMF_RESULT_ERROR_NOT_SUPPORTED;
5656
}
5757

58+
umf_result_t umfCUDAMemoryProviderParamsSetAllocFlags(
59+
umf_cuda_memory_provider_params_handle_t hParams, unsigned int flags) {
60+
(void)hParams;
61+
(void)flags;
62+
LOG_ERR("CUDA provider is disabled (UMF_BUILD_CUDA_PROVIDER is OFF)!");
63+
return UMF_RESULT_ERROR_NOT_SUPPORTED;
64+
}
65+
5866
umf_memory_provider_ops_t *umfCUDAMemoryProviderOps(void) {
5967
// not supported
6068
LOG_ERR("CUDA provider is disabled (UMF_BUILD_CUDA_PROVIDER is OFF)!");
@@ -89,21 +97,30 @@ typedef struct cu_memory_provider_t {
8997
CUdevice device;
9098
umf_usm_memory_type_t memory_type;
9199
size_t min_alignment;
100+
unsigned int alloc_flags;
92101
} cu_memory_provider_t;
93102

94103
// CUDA Memory Provider settings struct
95104
typedef struct umf_cuda_memory_provider_params_t {
96-
void *cuda_context_handle; ///< Handle to the CUDA context
97-
int cuda_device_handle; ///< Handle to the CUDA device
98-
umf_usm_memory_type_t memory_type; ///< Allocation memory type
105+
// Handle to the CUDA context
106+
void *cuda_context_handle;
107+
108+
// Handle to the CUDA device
109+
int cuda_device_handle;
110+
111+
// Allocation memory type
112+
umf_usm_memory_type_t memory_type;
113+
114+
// Allocation flags for cuMemHostAlloc/cuMemAllocManaged
115+
unsigned int alloc_flags;
99116
} umf_cuda_memory_provider_params_t;
100117

101118
typedef struct cu_ops_t {
102119
CUresult (*cuMemGetAllocationGranularity)(
103120
size_t *granularity, const CUmemAllocationProp *prop,
104121
CUmemAllocationGranularity_flags option);
105122
CUresult (*cuMemAlloc)(CUdeviceptr *dptr, size_t bytesize);
106-
CUresult (*cuMemAllocHost)(void **pp, size_t bytesize);
123+
CUresult (*cuMemHostAlloc)(void **pp, size_t bytesize, unsigned int flags);
107124
CUresult (*cuMemAllocManaged)(CUdeviceptr *dptr, size_t bytesize,
108125
unsigned int flags);
109126
CUresult (*cuMemFree)(CUdeviceptr dptr);
@@ -172,8 +189,8 @@ static void init_cu_global_state(void) {
172189
utils_get_symbol_addr(0, "cuMemGetAllocationGranularity", lib_name);
173190
*(void **)&g_cu_ops.cuMemAlloc =
174191
utils_get_symbol_addr(0, "cuMemAlloc_v2", lib_name);
175-
*(void **)&g_cu_ops.cuMemAllocHost =
176-
utils_get_symbol_addr(0, "cuMemAllocHost_v2", lib_name);
192+
*(void **)&g_cu_ops.cuMemHostAlloc =
193+
utils_get_symbol_addr(0, "cuMemHostAlloc", lib_name);
177194
*(void **)&g_cu_ops.cuMemAllocManaged =
178195
utils_get_symbol_addr(0, "cuMemAllocManaged", lib_name);
179196
*(void **)&g_cu_ops.cuMemFree =
@@ -196,7 +213,7 @@ static void init_cu_global_state(void) {
196213
utils_get_symbol_addr(0, "cuIpcCloseMemHandle", lib_name);
197214

198215
if (!g_cu_ops.cuMemGetAllocationGranularity || !g_cu_ops.cuMemAlloc ||
199-
!g_cu_ops.cuMemAllocHost || !g_cu_ops.cuMemAllocManaged ||
216+
!g_cu_ops.cuMemHostAlloc || !g_cu_ops.cuMemAllocManaged ||
200217
!g_cu_ops.cuMemFree || !g_cu_ops.cuMemFreeHost ||
201218
!g_cu_ops.cuGetErrorName || !g_cu_ops.cuGetErrorString ||
202219
!g_cu_ops.cuCtxGetCurrent || !g_cu_ops.cuCtxSetCurrent ||
@@ -225,6 +242,7 @@ umf_result_t umfCUDAMemoryProviderParamsCreate(
225242
params_data->cuda_context_handle = NULL;
226243
params_data->cuda_device_handle = -1;
227244
params_data->memory_type = UMF_MEMORY_TYPE_UNKNOWN;
245+
params_data->alloc_flags = 0;
228246

229247
*hParams = params_data;
230248

@@ -275,6 +293,18 @@ umf_result_t umfCUDAMemoryProviderParamsSetMemoryType(
275293
return UMF_RESULT_SUCCESS;
276294
}
277295

296+
umf_result_t umfCUDAMemoryProviderParamsSetAllocFlags(
297+
umf_cuda_memory_provider_params_handle_t hParams, unsigned int flags) {
298+
if (!hParams) {
299+
LOG_ERR("CUDA Memory Provider params handle is NULL");
300+
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
301+
}
302+
303+
hParams->alloc_flags = flags;
304+
305+
return UMF_RESULT_SUCCESS;
306+
}
307+
278308
static umf_result_t cu_memory_provider_initialize(void *params,
279309
void **provider) {
280310
if (params == NULL) {
@@ -325,6 +355,17 @@ static umf_result_t cu_memory_provider_initialize(void *params,
325355
cu_provider->memory_type = cu_params->memory_type;
326356
cu_provider->min_alignment = min_alignment;
327357

358+
// If the memory type is shared (CUDA managed), the allocation flags must
359+
// be set. NOTE: we do not check here if the flags are valid -
360+
// this will be done by CUDA runtime.
361+
if (cu_params->memory_type == UMF_MEMORY_TYPE_SHARED &&
362+
cu_params->alloc_flags == 0) {
363+
// the default setting is CU_MEM_ATTACH_GLOBAL
364+
cu_provider->alloc_flags = CU_MEM_ATTACH_GLOBAL;
365+
} else {
366+
cu_provider->alloc_flags = cu_params->alloc_flags;
367+
}
368+
328369
*provider = cu_provider;
329370

330371
return UMF_RESULT_SUCCESS;
@@ -381,7 +422,8 @@ static umf_result_t cu_memory_provider_alloc(void *provider, size_t size,
381422
CUresult cu_result = CUDA_SUCCESS;
382423
switch (cu_provider->memory_type) {
383424
case UMF_MEMORY_TYPE_HOST: {
384-
cu_result = g_cu_ops.cuMemAllocHost(resultPtr, size);
425+
cu_result =
426+
g_cu_ops.cuMemHostAlloc(resultPtr, size, cu_provider->alloc_flags);
385427
break;
386428
}
387429
case UMF_MEMORY_TYPE_DEVICE: {
@@ -390,7 +432,7 @@ static umf_result_t cu_memory_provider_alloc(void *provider, size_t size,
390432
}
391433
case UMF_MEMORY_TYPE_SHARED: {
392434
cu_result = g_cu_ops.cuMemAllocManaged((CUdeviceptr *)resultPtr, size,
393-
CU_MEM_ATTACH_GLOBAL);
435+
cu_provider->alloc_flags);
394436
break;
395437
}
396438
default:

test/providers/cuda_helpers.cpp

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ struct libcu_ops {
2222
CUresult (*cuDeviceGet)(CUdevice *device, int ordinal);
2323
CUresult (*cuMemAlloc)(CUdeviceptr *dptr, size_t size);
2424
CUresult (*cuMemFree)(CUdeviceptr dptr);
25-
CUresult (*cuMemAllocHost)(void **pp, size_t size);
25+
CUresult (*cuMemHostAlloc)(void **pp, size_t size, unsigned int flags);
2626
CUresult (*cuMemAllocManaged)(CUdeviceptr *dptr, size_t bytesize,
2727
unsigned int flags);
2828
CUresult (*cuMemFreeHost)(void *p);
@@ -34,6 +34,7 @@ struct libcu_ops {
3434
CUresult (*cuPointerGetAttributes)(unsigned int numAttributes,
3535
CUpointer_attribute *attributes,
3636
void **data, CUdeviceptr ptr);
37+
CUresult (*cuMemHostGetFlags)(unsigned int *pFlags, void *p);
3738
CUresult (*cuStreamSynchronize)(CUstream hStream);
3839
CUresult (*cuCtxSynchronize)(void);
3940
} libcu_ops;
@@ -69,7 +70,7 @@ struct DlHandleCloser {
6970
libcu_ops.cuMemFree = [](auto... args) {
7071
return noop_stub(args...);
7172
};
72-
libcu_ops.cuMemAllocHost = [](auto... args) {
73+
libcu_ops.cuMemHostAlloc = [](auto... args) {
7374
return noop_stub(args...);
7475
};
7576
libcu_ops.cuMemAllocManaged = [](auto... args) {
@@ -90,6 +91,9 @@ struct DlHandleCloser {
9091
libcu_ops.cuPointerGetAttributes = [](auto... args) {
9192
return noop_stub(args...);
9293
};
94+
libcu_ops.cuMemHostGetFlags = [](auto... args) {
95+
return noop_stub(args...);
96+
};
9397
libcu_ops.cuStreamSynchronize = [](auto... args) {
9498
return noop_stub(args...);
9599
};
@@ -164,10 +168,10 @@ int InitCUDAOps() {
164168
fprintf(stderr, "cuMemFree_v2 symbol not found in %s\n", lib_name);
165169
return -1;
166170
}
167-
*(void **)&libcu_ops.cuMemAllocHost =
168-
utils_get_symbol_addr(cuDlHandle.get(), "cuMemAllocHost_v2", lib_name);
169-
if (libcu_ops.cuMemAllocHost == nullptr) {
170-
fprintf(stderr, "cuMemAllocHost_v2 symbol not found in %s\n", lib_name);
171+
*(void **)&libcu_ops.cuMemHostAlloc =
172+
utils_get_symbol_addr(cuDlHandle.get(), "cuMemHostAlloc", lib_name);
173+
if (libcu_ops.cuMemHostAlloc == nullptr) {
174+
fprintf(stderr, "cuMemHostAlloc symbol not found in %s\n", lib_name);
171175
return -1;
172176
}
173177
*(void **)&libcu_ops.cuMemAllocManaged =
@@ -208,6 +212,12 @@ int InitCUDAOps() {
208212
lib_name);
209213
return -1;
210214
}
215+
*(void **)&libcu_ops.cuMemHostGetFlags =
216+
utils_get_symbol_addr(cuDlHandle.get(), "cuMemHostGetFlags", lib_name);
217+
if (libcu_ops.cuMemHostGetFlags == nullptr) {
218+
fprintf(stderr, "cuMemHostGetFlags symbol not found in %s\n", lib_name);
219+
return -1;
220+
}
211221
*(void **)&libcu_ops.cuStreamSynchronize = utils_get_symbol_addr(
212222
cuDlHandle.get(), "cuStreamSynchronize", lib_name);
213223
if (libcu_ops.cuStreamSynchronize == nullptr) {
@@ -236,14 +246,15 @@ int InitCUDAOps() {
236246
libcu_ops.cuCtxSetCurrent = cuCtxSetCurrent;
237247
libcu_ops.cuDeviceGet = cuDeviceGet;
238248
libcu_ops.cuMemAlloc = cuMemAlloc;
239-
libcu_ops.cuMemAllocHost = cuMemAllocHost;
249+
libcu_ops.cuMemHostAlloc = cuMemHostAlloc;
240250
libcu_ops.cuMemAllocManaged = cuMemAllocManaged;
241251
libcu_ops.cuMemFree = cuMemFree;
242252
libcu_ops.cuMemFreeHost = cuMemFreeHost;
243253
libcu_ops.cuMemsetD32 = cuMemsetD32;
244254
libcu_ops.cuMemcpy = cuMemcpy;
245255
libcu_ops.cuPointerGetAttribute = cuPointerGetAttribute;
246256
libcu_ops.cuPointerGetAttributes = cuPointerGetAttributes;
257+
libcu_ops.cuMemHostGetFlags = cuMemHostGetFlags;
247258
libcu_ops.cuStreamSynchronize = cuStreamSynchronize;
248259
libcu_ops.cuCtxSynchronize = cuCtxSynchronize;
249260

@@ -373,6 +384,17 @@ umf_usm_memory_type_t get_mem_type(CUcontext context, void *ptr) {
373384
return UMF_MEMORY_TYPE_UNKNOWN;
374385
}
375386

387+
unsigned int get_mem_host_alloc_flags(void *ptr) {
388+
unsigned int flags;
389+
CUresult res = libcu_ops.cuMemHostGetFlags(&flags, ptr);
390+
if (res != CUDA_SUCCESS) {
391+
fprintf(stderr, "cuPointerGetAttribute() failed!\n");
392+
return 0;
393+
}
394+
395+
return flags;
396+
}
397+
376398
CUcontext get_mem_context(void *ptr) {
377399
CUcontext context;
378400
CUresult res = libcu_ops.cuPointerGetAttribute(

test/providers/cuda_helpers.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ int cuda_copy(CUcontext context, CUdevice device, void *dst_ptr,
4444

4545
umf_usm_memory_type_t get_mem_type(CUcontext context, void *ptr);
4646

47+
unsigned int get_mem_host_alloc_flags(void *ptr);
48+
4749
CUcontext get_mem_context(void *ptr);
4850

4951
CUcontext get_current_context();

0 commit comments

Comments
 (0)