Skip to content

add support for CUDA allocation flags #1079

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion include/umf/providers/provider_cuda.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (C) 2024 Intel Corporation
* Copyright (C) 2024-2025 Intel Corporation
*
* Under the Apache License v2.0 with LLVM Exceptions. See LICENSE.TXT.
* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
Expand Down Expand Up @@ -53,6 +53,13 @@ umf_result_t umfCUDAMemoryProviderParamsSetMemoryType(
umf_cuda_memory_provider_params_handle_t hParams,
umf_usm_memory_type_t memoryType);

/// @brief Set the allocation flags in the parameters struct.
/// @param hParams handle to the parameters of the CUDA Memory Provider.
/// @param flags valid combination of CUDA allocation flags.
/// @return UMF_RESULT_SUCCESS on success or appropriate error code on failure.
umf_result_t umfCUDAMemoryProviderParamsSetAllocFlags(
umf_cuda_memory_provider_params_handle_t hParams, unsigned int flags);

umf_memory_provider_ops_t *umfCUDAMemoryProviderOps(void);

#ifdef __cplusplus
Expand Down
1 change: 1 addition & 0 deletions src/libumf.def
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ EXPORTS
umfScalablePoolParamsSetGranularity
umfScalablePoolParamsSetKeepAllMemory
; Added in UMF_0.11
umfCUDAMemoryProviderParamsSetAllocFlags
umfFixedMemoryProviderOps
umfFixedMemoryProviderParamsCreate
umfFixedMemoryProviderParamsDestroy
Expand Down
1 change: 1 addition & 0 deletions src/libumf.map
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ UMF_0.10 {
};

UMF_0.11 {
umfCUDAMemoryProviderParamsSetAllocFlags;
umfFixedMemoryProviderOps;
umfFixedMemoryProviderParamsCreate;
umfFixedMemoryProviderParamsDestroy;
Expand Down
60 changes: 51 additions & 9 deletions src/provider/provider_cuda.c
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,14 @@ umf_result_t umfCUDAMemoryProviderParamsSetMemoryType(
return UMF_RESULT_ERROR_NOT_SUPPORTED;
}

umf_result_t umfCUDAMemoryProviderParamsSetAllocFlags(
umf_cuda_memory_provider_params_handle_t hParams, unsigned int flags) {
(void)hParams;
(void)flags;
LOG_ERR("CUDA provider is disabled (UMF_BUILD_CUDA_PROVIDER is OFF)!");
return UMF_RESULT_ERROR_NOT_SUPPORTED;
}

umf_memory_provider_ops_t *umfCUDAMemoryProviderOps(void) {
// not supported
LOG_ERR("CUDA provider is disabled (UMF_BUILD_CUDA_PROVIDER is OFF)!");
Expand Down Expand Up @@ -89,21 +97,30 @@ typedef struct cu_memory_provider_t {
CUdevice device;
umf_usm_memory_type_t memory_type;
size_t min_alignment;
unsigned int alloc_flags;
} cu_memory_provider_t;

// CUDA Memory Provider settings struct
typedef struct umf_cuda_memory_provider_params_t {
void *cuda_context_handle; ///< Handle to the CUDA context
int cuda_device_handle; ///< Handle to the CUDA device
umf_usm_memory_type_t memory_type; ///< Allocation memory type
// Handle to the CUDA context
void *cuda_context_handle;

// Handle to the CUDA device
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are no longer doxygen style comments. Is that intended since this is implementation and non-public?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this code was moved from the public include files to this location when we updated the UMF API to use setters instead of directly exposign the parameter structure. So this change is just a cleanup

int cuda_device_handle;

// Allocation memory type
umf_usm_memory_type_t memory_type;

// Allocation flags for cuMemHostAlloc/cuMemAllocManaged
unsigned int alloc_flags;
} umf_cuda_memory_provider_params_t;

typedef struct cu_ops_t {
CUresult (*cuMemGetAllocationGranularity)(
size_t *granularity, const CUmemAllocationProp *prop,
CUmemAllocationGranularity_flags option);
CUresult (*cuMemAlloc)(CUdeviceptr *dptr, size_t bytesize);
CUresult (*cuMemAllocHost)(void **pp, size_t bytesize);
CUresult (*cuMemHostAlloc)(void **pp, size_t bytesize, unsigned int flags);
CUresult (*cuMemAllocManaged)(CUdeviceptr *dptr, size_t bytesize,
unsigned int flags);
CUresult (*cuMemFree)(CUdeviceptr dptr);
Expand Down Expand Up @@ -172,8 +189,8 @@ static void init_cu_global_state(void) {
utils_get_symbol_addr(0, "cuMemGetAllocationGranularity", lib_name);
*(void **)&g_cu_ops.cuMemAlloc =
utils_get_symbol_addr(0, "cuMemAlloc_v2", lib_name);
*(void **)&g_cu_ops.cuMemAllocHost =
utils_get_symbol_addr(0, "cuMemAllocHost_v2", lib_name);
*(void **)&g_cu_ops.cuMemHostAlloc =
utils_get_symbol_addr(0, "cuMemHostAlloc", lib_name);
*(void **)&g_cu_ops.cuMemAllocManaged =
utils_get_symbol_addr(0, "cuMemAllocManaged", lib_name);
*(void **)&g_cu_ops.cuMemFree =
Expand All @@ -196,7 +213,7 @@ static void init_cu_global_state(void) {
utils_get_symbol_addr(0, "cuIpcCloseMemHandle", lib_name);

if (!g_cu_ops.cuMemGetAllocationGranularity || !g_cu_ops.cuMemAlloc ||
!g_cu_ops.cuMemAllocHost || !g_cu_ops.cuMemAllocManaged ||
!g_cu_ops.cuMemHostAlloc || !g_cu_ops.cuMemAllocManaged ||
!g_cu_ops.cuMemFree || !g_cu_ops.cuMemFreeHost ||
!g_cu_ops.cuGetErrorName || !g_cu_ops.cuGetErrorString ||
!g_cu_ops.cuCtxGetCurrent || !g_cu_ops.cuCtxSetCurrent ||
Expand Down Expand Up @@ -225,6 +242,7 @@ umf_result_t umfCUDAMemoryProviderParamsCreate(
params_data->cuda_context_handle = NULL;
params_data->cuda_device_handle = -1;
params_data->memory_type = UMF_MEMORY_TYPE_UNKNOWN;
params_data->alloc_flags = 0;

*hParams = params_data;

Expand Down Expand Up @@ -275,6 +293,18 @@ umf_result_t umfCUDAMemoryProviderParamsSetMemoryType(
return UMF_RESULT_SUCCESS;
}

umf_result_t umfCUDAMemoryProviderParamsSetAllocFlags(
umf_cuda_memory_provider_params_handle_t hParams, unsigned int flags) {
if (!hParams) {
LOG_ERR("CUDA Memory Provider params handle is NULL");
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
}

hParams->alloc_flags = flags;

return UMF_RESULT_SUCCESS;
}

static umf_result_t cu_memory_provider_initialize(void *params,
void **provider) {
if (params == NULL) {
Expand Down Expand Up @@ -325,6 +355,17 @@ static umf_result_t cu_memory_provider_initialize(void *params,
cu_provider->memory_type = cu_params->memory_type;
cu_provider->min_alignment = min_alignment;

// If the memory type is shared (CUDA managed), the allocation flags must
// be set. NOTE: we do not check here if the flags are valid -
// this will be done by CUDA runtime.
if (cu_params->memory_type == UMF_MEMORY_TYPE_SHARED &&
cu_params->alloc_flags == 0) {
// the default setting is CU_MEM_ATTACH_GLOBAL
cu_provider->alloc_flags = CU_MEM_ATTACH_GLOBAL;
} else {
cu_provider->alloc_flags = cu_params->alloc_flags;
}

*provider = cu_provider;

return UMF_RESULT_SUCCESS;
Expand Down Expand Up @@ -381,7 +422,8 @@ static umf_result_t cu_memory_provider_alloc(void *provider, size_t size,
CUresult cu_result = CUDA_SUCCESS;
switch (cu_provider->memory_type) {
case UMF_MEMORY_TYPE_HOST: {
cu_result = g_cu_ops.cuMemAllocHost(resultPtr, size);
cu_result =
g_cu_ops.cuMemHostAlloc(resultPtr, size, cu_provider->alloc_flags);
break;
}
case UMF_MEMORY_TYPE_DEVICE: {
Expand All @@ -390,7 +432,7 @@ static umf_result_t cu_memory_provider_alloc(void *provider, size_t size,
}
case UMF_MEMORY_TYPE_SHARED: {
cu_result = g_cu_ops.cuMemAllocManaged((CUdeviceptr *)resultPtr, size,
CU_MEM_ATTACH_GLOBAL);
cu_provider->alloc_flags);
break;
}
default:
Expand Down
36 changes: 29 additions & 7 deletions test/providers/cuda_helpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ struct libcu_ops {
CUresult (*cuDeviceGet)(CUdevice *device, int ordinal);
CUresult (*cuMemAlloc)(CUdeviceptr *dptr, size_t size);
CUresult (*cuMemFree)(CUdeviceptr dptr);
CUresult (*cuMemAllocHost)(void **pp, size_t size);
CUresult (*cuMemHostAlloc)(void **pp, size_t size, unsigned int flags);
CUresult (*cuMemAllocManaged)(CUdeviceptr *dptr, size_t bytesize,
unsigned int flags);
CUresult (*cuMemFreeHost)(void *p);
Expand All @@ -34,6 +34,7 @@ struct libcu_ops {
CUresult (*cuPointerGetAttributes)(unsigned int numAttributes,
CUpointer_attribute *attributes,
void **data, CUdeviceptr ptr);
CUresult (*cuMemHostGetFlags)(unsigned int *pFlags, void *p);
CUresult (*cuStreamSynchronize)(CUstream hStream);
CUresult (*cuCtxSynchronize)(void);
} libcu_ops;
Expand Down Expand Up @@ -69,7 +70,7 @@ struct DlHandleCloser {
libcu_ops.cuMemFree = [](auto... args) {
return noop_stub(args...);
};
libcu_ops.cuMemAllocHost = [](auto... args) {
libcu_ops.cuMemHostAlloc = [](auto... args) {
return noop_stub(args...);
};
libcu_ops.cuMemAllocManaged = [](auto... args) {
Expand All @@ -90,6 +91,9 @@ struct DlHandleCloser {
libcu_ops.cuPointerGetAttributes = [](auto... args) {
return noop_stub(args...);
};
libcu_ops.cuMemHostGetFlags = [](auto... args) {
return noop_stub(args...);
};
libcu_ops.cuStreamSynchronize = [](auto... args) {
return noop_stub(args...);
};
Expand Down Expand Up @@ -164,10 +168,10 @@ int InitCUDAOps() {
fprintf(stderr, "cuMemFree_v2 symbol not found in %s\n", lib_name);
return -1;
}
*(void **)&libcu_ops.cuMemAllocHost =
utils_get_symbol_addr(cuDlHandle.get(), "cuMemAllocHost_v2", lib_name);
if (libcu_ops.cuMemAllocHost == nullptr) {
fprintf(stderr, "cuMemAllocHost_v2 symbol not found in %s\n", lib_name);
*(void **)&libcu_ops.cuMemHostAlloc =
utils_get_symbol_addr(cuDlHandle.get(), "cuMemHostAlloc", lib_name);
if (libcu_ops.cuMemHostAlloc == nullptr) {
fprintf(stderr, "cuMemHostAlloc symbol not found in %s\n", lib_name);
return -1;
}
*(void **)&libcu_ops.cuMemAllocManaged =
Expand Down Expand Up @@ -208,6 +212,12 @@ int InitCUDAOps() {
lib_name);
return -1;
}
*(void **)&libcu_ops.cuMemHostGetFlags =
utils_get_symbol_addr(cuDlHandle.get(), "cuMemHostGetFlags", lib_name);
if (libcu_ops.cuMemHostGetFlags == nullptr) {
fprintf(stderr, "cuMemHostGetFlags symbol not found in %s\n", lib_name);
return -1;
}
*(void **)&libcu_ops.cuStreamSynchronize = utils_get_symbol_addr(
cuDlHandle.get(), "cuStreamSynchronize", lib_name);
if (libcu_ops.cuStreamSynchronize == nullptr) {
Expand Down Expand Up @@ -236,14 +246,15 @@ int InitCUDAOps() {
libcu_ops.cuCtxSetCurrent = cuCtxSetCurrent;
libcu_ops.cuDeviceGet = cuDeviceGet;
libcu_ops.cuMemAlloc = cuMemAlloc;
libcu_ops.cuMemAllocHost = cuMemAllocHost;
libcu_ops.cuMemHostAlloc = cuMemHostAlloc;
libcu_ops.cuMemAllocManaged = cuMemAllocManaged;
libcu_ops.cuMemFree = cuMemFree;
libcu_ops.cuMemFreeHost = cuMemFreeHost;
libcu_ops.cuMemsetD32 = cuMemsetD32;
libcu_ops.cuMemcpy = cuMemcpy;
libcu_ops.cuPointerGetAttribute = cuPointerGetAttribute;
libcu_ops.cuPointerGetAttributes = cuPointerGetAttributes;
libcu_ops.cuMemHostGetFlags = cuMemHostGetFlags;
libcu_ops.cuStreamSynchronize = cuStreamSynchronize;
libcu_ops.cuCtxSynchronize = cuCtxSynchronize;

Expand Down Expand Up @@ -373,6 +384,17 @@ umf_usm_memory_type_t get_mem_type(CUcontext context, void *ptr) {
return UMF_MEMORY_TYPE_UNKNOWN;
}

unsigned int get_mem_host_alloc_flags(void *ptr) {
unsigned int flags;
CUresult res = libcu_ops.cuMemHostGetFlags(&flags, ptr);
if (res != CUDA_SUCCESS) {
fprintf(stderr, "cuPointerGetAttribute() failed!\n");
return 0;
}

return flags;
}

CUcontext get_mem_context(void *ptr) {
CUcontext context;
CUresult res = libcu_ops.cuPointerGetAttribute(
Expand Down
2 changes: 2 additions & 0 deletions test/providers/cuda_helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ int cuda_copy(CUcontext context, CUdevice device, void *dst_ptr,

umf_usm_memory_type_t get_mem_type(CUcontext context, void *ptr);

unsigned int get_mem_host_alloc_flags(void *ptr);

CUcontext get_mem_context(void *ptr);

CUcontext get_current_context();
Expand Down
Loading
Loading