@@ -55,6 +55,22 @@ umf_result_t umfCUDAMemoryProviderParamsSetMemoryType(
     return UMF_RESULT_ERROR_NOT_SUPPORTED;
 }
 
+umf_result_t umfCUDAMemoryProviderParamsSetHostAllocFlags(
+    umf_cuda_memory_provider_params_handle_t hParams, unsigned int flags) {
+    (void)hParams;
+    (void)flags;
+    LOG_ERR("CUDA provider is disabled (UMF_BUILD_CUDA_PROVIDER is OFF)!");
+    return UMF_RESULT_ERROR_NOT_SUPPORTED;
+}
+
+umf_result_t umfCUDAMemoryProviderParamsSetManagedAllocFlags(
+    umf_cuda_memory_provider_params_handle_t hParams, unsigned int flags) {
+    (void)hParams;
+    (void)flags;
+    LOG_ERR("CUDA provider is disabled (UMF_BUILD_CUDA_PROVIDER is OFF)!");
+    return UMF_RESULT_ERROR_NOT_SUPPORTED;
+}
+
 umf_memory_provider_ops_t *umfCUDAMemoryProviderOps(void) {
     // not supported
     LOG_ERR("CUDA provider is disabled (UMF_BUILD_CUDA_PROVIDER is OFF)!");
@@ -89,13 +105,17 @@ typedef struct cu_memory_provider_t {
     CUdevice device;
     umf_usm_memory_type_t memory_type;
     size_t min_alignment;
+    unsigned int host_alloc_flags;
+    unsigned int managed_alloc_flags;
 } cu_memory_provider_t;
 
 // CUDA Memory Provider settings struct
 typedef struct umf_cuda_memory_provider_params_t {
-    void *cuda_context_handle;         ///< Handle to the CUDA context
-    int cuda_device_handle;            ///< Handle to the CUDA device
-    umf_usm_memory_type_t memory_type; ///< Allocation memory type
+    void *cuda_context_handle;         // Handle to the CUDA context
+    int cuda_device_handle;            // Handle to the CUDA device
+    umf_usm_memory_type_t memory_type; // Allocation memory type
+    unsigned int host_alloc_flags;     // Allocation flags for cuMemHostAlloc
+    unsigned int managed_alloc_flags;  // Allocation flags for cuMemAllocManaged
 } umf_cuda_memory_provider_params_t;
 
 typedef struct cu_ops_t {
@@ -104,6 +124,7 @@ typedef struct cu_ops_t {
                                              CUmemAllocationGranularity_flags option);
     CUresult (*cuMemAlloc)(CUdeviceptr *dptr, size_t bytesize);
     CUresult (*cuMemAllocHost)(void **pp, size_t bytesize);
+    CUresult (*cuMemHostAlloc)(void **pp, size_t bytesize, unsigned int flags);
     CUresult (*cuMemAllocManaged)(CUdeviceptr *dptr, size_t bytesize,
                                   unsigned int flags);
     CUresult (*cuMemFree)(CUdeviceptr dptr);
@@ -175,6 +196,8 @@ static void init_cu_global_state(void) {
         utils_get_symbol_addr(0, "cuMemAlloc_v2", lib_name);
     *(void **)&g_cu_ops.cuMemAllocHost =
         utils_get_symbol_addr(0, "cuMemAllocHost_v2", lib_name);
+    *(void **)&g_cu_ops.cuMemHostAlloc =
+        utils_get_symbol_addr(0, "cuMemHostAlloc", lib_name);
     *(void **)&g_cu_ops.cuMemAllocManaged =
         utils_get_symbol_addr(0, "cuMemAllocManaged", lib_name);
     *(void **)&g_cu_ops.cuMemFree =
@@ -197,12 +220,12 @@ static void init_cu_global_state(void) {
         utils_get_symbol_addr(0, "cuIpcCloseMemHandle", lib_name);
 
     if (!g_cu_ops.cuMemGetAllocationGranularity || !g_cu_ops.cuMemAlloc ||
-        !g_cu_ops.cuMemAllocHost || !g_cu_ops.cuMemAllocManaged ||
-        !g_cu_ops.cuMemFree || !g_cu_ops.cuMemFreeHost ||
-        !g_cu_ops.cuGetErrorName || !g_cu_ops.cuGetErrorString ||
-        !g_cu_ops.cuCtxGetCurrent || !g_cu_ops.cuCtxSetCurrent ||
-        !g_cu_ops.cuIpcGetMemHandle || !g_cu_ops.cuIpcOpenMemHandle ||
-        !g_cu_ops.cuIpcCloseMemHandle) {
+        !g_cu_ops.cuMemAllocHost || !g_cu_ops.cuMemHostAlloc ||
+        !g_cu_ops.cuMemAllocManaged || !g_cu_ops.cuMemFree ||
+        !g_cu_ops.cuMemFreeHost || !g_cu_ops.cuGetErrorName ||
+        !g_cu_ops.cuGetErrorString || !g_cu_ops.cuCtxGetCurrent ||
+        !g_cu_ops.cuCtxSetCurrent || !g_cu_ops.cuIpcGetMemHandle ||
+        !g_cu_ops.cuIpcOpenMemHandle || !g_cu_ops.cuIpcCloseMemHandle) {
         LOG_ERR("Required CUDA symbols not found.");
         Init_cu_global_state_failed = true;
     }
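
Note that cuMemHostAlloc is resolved by its unsuffixed name, while cuMemAllocHost is resolved as cuMemAllocHost_v2; cuMemHostAlloc never went through a "_v2" rename in the driver, so there is no alias to prefer. A standalone sketch of the same lazy-binding pattern, with dlopen/dlsym standing in for the internal utils_get_symbol_addr helper (an assumption; the real helper also handles Windows and an already-loaded driver):

    #include <dlfcn.h>
    #include <stddef.h>

    // Stand-in typedefs so the sketch compiles without cuda.h.
    typedef int CUresult_int;
    typedef CUresult_int (*cuMemHostAlloc_fn)(void **pp, size_t bytesize,
                                              unsigned int flags);

    static cuMemHostAlloc_fn load_cuMemHostAlloc(void) {
        void *lib = dlopen("libcuda.so.1", RTLD_LAZY | RTLD_LOCAL);
        if (!lib) {
            return NULL; // no CUDA driver installed
        }
        // Unsuffixed lookup, mirroring the line added above.
        return (cuMemHostAlloc_fn)dlsym(lib, "cuMemHostAlloc");
    }
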
@@ -226,6 +249,8 @@ umf_result_t umfCUDAMemoryProviderParamsCreate(
     params_data->cuda_context_handle = NULL;
     params_data->cuda_device_handle = -1;
     params_data->memory_type = UMF_MEMORY_TYPE_UNKNOWN;
+    params_data->host_alloc_flags = 0;
+    params_data->managed_alloc_flags = CU_MEM_ATTACH_GLOBAL;
 
     *hParams = params_data;
 
@@ -276,6 +301,42 @@ umf_result_t umfCUDAMemoryProviderParamsSetMemoryType(
     return UMF_RESULT_SUCCESS;
 }
 
+umf_result_t umfCUDAMemoryProviderParamsSetHostAllocFlags(
+    umf_cuda_memory_provider_params_handle_t hParams, unsigned int flags) {
+    if (!hParams) {
+        LOG_ERR("CUDA Memory Provider params handle is NULL");
+        return UMF_RESULT_ERROR_INVALID_ARGUMENT;
+    }
+
+    // mask out valid flags and check if there are bits left
+    if (flags & ~(CU_MEMHOSTALLOC_PORTABLE | CU_MEMHOSTALLOC_DEVICEMAP |
+                  CU_MEMHOSTALLOC_WRITECOMBINED)) {
+        LOG_ERR("Invalid host allocation flags");
+        return UMF_RESULT_ERROR_INVALID_ARGUMENT;
+    }
+
+    hParams->host_alloc_flags = flags;
+
+    return UMF_RESULT_SUCCESS;
+}
+
+umf_result_t umfCUDAMemoryProviderParamsSetManagedAllocFlags(
+    umf_cuda_memory_provider_params_handle_t hParams, unsigned int flags) {
+    if (!hParams) {
+        LOG_ERR("CUDA Memory Provider params handle is NULL");
+        return UMF_RESULT_ERROR_INVALID_ARGUMENT;
+    }
+
+    if (flags != CU_MEM_ATTACH_GLOBAL && flags != CU_MEM_ATTACH_HOST) {
+        LOG_ERR("Invalid managed allocation flags");
+        return UMF_RESULT_ERROR_INVALID_ARGUMENT;
+    }
+
+    hParams->managed_alloc_flags = flags;
+
+    return UMF_RESULT_SUCCESS;
+}
+
 static umf_result_t cu_memory_provider_initialize(void *params,
                                                   void **provider) {
     if (params == NULL) {
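
The two setters validate in different ways: host flags form a bitmask checked against the three known CU_MEMHOSTALLOC_* bits, while the managed flag must be exactly one attachment scope. A usage sketch, assuming umfCUDAMemoryProviderParamsDestroy as the counterpart of Create (everything else is in this patch):

    // Build a params object carrying custom host and managed alloc flags.
    static umf_result_t make_cuda_params(
        umf_cuda_memory_provider_params_handle_t *out) {
        umf_result_t res = umfCUDAMemoryProviderParamsCreate(out);
        if (res != UMF_RESULT_SUCCESS) {
            return res;
        }
        // Pinned, portable, device-mapped host memory; any unknown bit
        // would fail the mask check above.
        res = umfCUDAMemoryProviderParamsSetHostAllocFlags(
            *out, CU_MEMHOSTALLOC_PORTABLE | CU_MEMHOSTALLOC_DEVICEMAP);
        if (res == UMF_RESULT_SUCCESS) {
            // Exactly one attachment scope is accepted.
            res = umfCUDAMemoryProviderParamsSetManagedAllocFlags(
                *out, CU_MEM_ATTACH_HOST);
        }
        if (res != UMF_RESULT_SUCCESS) {
            umfCUDAMemoryProviderParamsDestroy(*out);
            *out = NULL;
        }
        return res;
    }
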
@@ -325,6 +386,8 @@ static umf_result_t cu_memory_provider_initialize(void *params,
     cu_provider->device = cu_params->cuda_device_handle;
     cu_provider->memory_type = cu_params->memory_type;
     cu_provider->min_alignment = min_alignment;
+    cu_provider->host_alloc_flags = cu_params->host_alloc_flags;
+    cu_provider->managed_alloc_flags = cu_params->managed_alloc_flags;
 
     *provider = cu_provider;
 
@@ -382,16 +445,17 @@ static umf_result_t cu_memory_provider_alloc(void *provider, size_t size,
     CUresult cu_result = CUDA_SUCCESS;
     switch (cu_provider->memory_type) {
     case UMF_MEMORY_TYPE_HOST: {
-        cu_result = g_cu_ops.cuMemAllocHost(resultPtr, size);
+        cu_result = g_cu_ops.cuMemHostAlloc(resultPtr, size,
+                                            cu_provider->host_alloc_flags);
         break;
     }
     case UMF_MEMORY_TYPE_DEVICE: {
         cu_result = g_cu_ops.cuMemAlloc((CUdeviceptr *)resultPtr, size);
         break;
     }
     case UMF_MEMORY_TYPE_SHARED: {
-        cu_result = g_cu_ops.cuMemAllocManaged((CUdeviceptr *)resultPtr, size,
-                                               CU_MEM_ATTACH_GLOBAL);
+        cu_result = g_cu_ops.cuMemAllocManaged(
+            (CUdeviceptr *)resultPtr, size, cu_provider->managed_alloc_flags);
         break;
     }
     default:
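
With the flags stored on the provider, host and shared allocations now honor the user's configuration instead of hard-coded values, while the defaults set in ParamsCreate (0 and CU_MEM_ATTACH_GLOBAL) keep the previous behavior. An end-to-end sketch continuing from the make_cuda_params sketch above, assuming the generic provider API (umfMemoryProviderCreate/Alloc/Free/Destroy) keeps its usual shape:

    umf_memory_provider_handle_t provider = NULL;
    umfCUDAMemoryProviderParamsSetMemoryType(params, UMF_MEMORY_TYPE_HOST);
    if (umfMemoryProviderCreate(umfCUDAMemoryProviderOps(), params,
                                &provider) == UMF_RESULT_SUCCESS) {
        void *ptr = NULL;
        // Routed to the UMF_MEMORY_TYPE_HOST case above, i.e.
        // cuMemHostAlloc(&ptr, 1 << 20, host_alloc_flags).
        if (umfMemoryProviderAlloc(provider, 1 << 20, 0, &ptr) ==
            UMF_RESULT_SUCCESS) {
            umfMemoryProviderFree(provider, ptr, 1 << 20);
        }
        umfMemoryProviderDestroy(provider);
    }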