Skip to content

Commit b8088be

Browse files
Merge pull request #810 from vinser52/svinogra_cuda_fix
Fix CUDA provider
2 parents 7cbe2e9 + b20cf01 commit b8088be

File tree

4 files changed

+141
-48
lines changed

4 files changed

+141
-48
lines changed

src/provider/provider_cuda.c

Lines changed: 55 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ typedef struct cu_ops_t {
5151

5252
CUresult (*cuGetErrorName)(CUresult error, const char **pStr);
5353
CUresult (*cuGetErrorString)(CUresult error, const char **pStr);
54+
CUresult (*cuCtxGetCurrent)(CUcontext *pctx);
55+
CUresult (*cuCtxSetCurrent)(CUcontext ctx);
5456
} cu_ops_t;
5557

5658
static cu_ops_t g_cu_ops;
@@ -117,11 +119,16 @@ static void init_cu_global_state(void) {
117119
utils_get_symbol_addr(0, "cuGetErrorName", lib_name);
118120
*(void **)&g_cu_ops.cuGetErrorString =
119121
utils_get_symbol_addr(0, "cuGetErrorString", lib_name);
122+
*(void **)&g_cu_ops.cuCtxGetCurrent =
123+
utils_get_symbol_addr(0, "cuCtxGetCurrent", lib_name);
124+
*(void **)&g_cu_ops.cuCtxSetCurrent =
125+
utils_get_symbol_addr(0, "cuCtxSetCurrent", lib_name);
120126

121127
if (!g_cu_ops.cuMemGetAllocationGranularity || !g_cu_ops.cuMemAlloc ||
122128
!g_cu_ops.cuMemAllocHost || !g_cu_ops.cuMemAllocManaged ||
123129
!g_cu_ops.cuMemFree || !g_cu_ops.cuMemFreeHost ||
124-
!g_cu_ops.cuGetErrorName || !g_cu_ops.cuGetErrorString) {
130+
!g_cu_ops.cuGetErrorName || !g_cu_ops.cuGetErrorString ||
131+
!g_cu_ops.cuCtxGetCurrent || !g_cu_ops.cuCtxSetCurrent) {
125132
LOG_ERR("Required CUDA symbols not found.");
126133
Init_cu_global_state_failed = true;
127134
}
@@ -190,6 +197,31 @@ static void cu_memory_provider_finalize(void *provider) {
190197
umf_ba_global_free(provider);
191198
}
192199

200+
/*
201+
* This function is used by the CUDA provider to make sure that
202+
* the required context is set. If the current context is
203+
* not the required one, it will be saved in restore_ctx.
204+
*/
205+
static inline umf_result_t set_context(CUcontext required_ctx,
206+
CUcontext *restore_ctx) {
207+
CUcontext current_ctx = NULL;
208+
CUresult cu_result = g_cu_ops.cuCtxGetCurrent(&current_ctx);
209+
if (cu_result != CUDA_SUCCESS) {
210+
LOG_ERR("cuCtxGetCurrent() failed.");
211+
return cu2umf_result(cu_result);
212+
}
213+
*restore_ctx = current_ctx;
214+
if (current_ctx != required_ctx) {
215+
cu_result = g_cu_ops.cuCtxSetCurrent(required_ctx);
216+
if (cu_result != CUDA_SUCCESS) {
217+
LOG_ERR("cuCtxSetCurrent() failed.");
218+
return cu2umf_result(cu_result);
219+
}
220+
}
221+
222+
return UMF_RESULT_SUCCESS;
223+
}
224+
193225
static umf_result_t cu_memory_provider_alloc(void *provider, size_t size,
194226
size_t alignment,
195227
void **resultPtr) {
@@ -205,6 +237,14 @@ static umf_result_t cu_memory_provider_alloc(void *provider, size_t size,
205237
return UMF_RESULT_ERROR_NOT_SUPPORTED;
206238
}
207239

240+
// Remember current context and set the one from the provider
241+
CUcontext restore_ctx = NULL;
242+
umf_result_t umf_result = set_context(cu_provider->context, &restore_ctx);
243+
if (umf_result != UMF_RESULT_SUCCESS) {
244+
LOG_ERR("Failed to set CUDA context, ret = %d", umf_result);
245+
return umf_result;
246+
}
247+
208248
CUresult cu_result = CUDA_SUCCESS;
209249
switch (cu_provider->memory_type) {
210250
case UMF_MEMORY_TYPE_HOST: {
@@ -224,17 +264,29 @@ static umf_result_t cu_memory_provider_alloc(void *provider, size_t size,
224264
// this shouldn't happen as we check the memory_type settings during
225265
// the initialization
226266
LOG_ERR("unsupported USM memory type");
267+
assert(false);
227268
return UMF_RESULT_ERROR_UNKNOWN;
228269
}
229270

271+
umf_result = set_context(restore_ctx, &restore_ctx);
272+
if (umf_result != UMF_RESULT_SUCCESS) {
273+
LOG_ERR("Failed to restore CUDA context, ret = %d", umf_result);
274+
}
275+
276+
umf_result = cu2umf_result(cu_result);
277+
if (umf_result != UMF_RESULT_SUCCESS) {
278+
LOG_ERR("Failed to allocate memory, cu_result = %d, ret = %d",
279+
cu_result, umf_result);
280+
return umf_result;
281+
}
282+
230283
// check the alignment
231284
if (alignment > 0 && ((uintptr_t)(*resultPtr) % alignment) != 0) {
232285
cu_memory_provider_free(provider, *resultPtr, size);
233286
LOG_ERR("unsupported alignment size");
234287
return UMF_RESULT_ERROR_INVALID_ALIGNMENT;
235288
}
236-
237-
return cu2umf_result(cu_result);
289+
return umf_result;
238290
}
239291

240292
static umf_result_t cu_memory_provider_free(void *provider, void *ptr,

test/providers/cuda_helpers.cpp

Lines changed: 50 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ struct libcu_ops {
1717
CUresult (*cuInit)(unsigned int flags);
1818
CUresult (*cuCtxCreate)(CUcontext *pctx, unsigned int flags, CUdevice dev);
1919
CUresult (*cuCtxDestroy)(CUcontext ctx);
20+
CUresult (*cuCtxGetCurrent)(CUcontext *pctx);
2021
CUresult (*cuDeviceGet)(CUdevice *device, int ordinal);
2122
CUresult (*cuMemAlloc)(CUdeviceptr *dptr, size_t size);
2223
CUresult (*cuMemFree)(CUdeviceptr dptr);
@@ -26,7 +27,9 @@ struct libcu_ops {
2627
CUresult (*cuMemFreeHost)(void *p);
2728
CUresult (*cuMemsetD32)(CUdeviceptr dstDevice, unsigned int pattern,
2829
size_t size);
29-
CUresult (*cuMemcpyDtoH)(void *dstHost, CUdeviceptr srcDevice, size_t size);
30+
CUresult (*cuMemcpy)(CUdeviceptr dst, CUdeviceptr src, size_t size);
31+
CUresult (*cuPointerGetAttribute)(void *data, CUpointer_attribute attribute,
32+
CUdeviceptr ptr);
3033
CUresult (*cuPointerGetAttributes)(unsigned int numAttributes,
3134
CUpointer_attribute *attributes,
3235
void **data, CUdeviceptr ptr);
@@ -74,6 +77,12 @@ int InitCUDAOps() {
7477
fprintf(stderr, "cuCtxDestroy symbol not found in %s\n", lib_name);
7578
return -1;
7679
}
80+
*(void **)&libcu_ops.cuCtxGetCurrent =
81+
utils_get_symbol_addr(cuDlHandle.get(), "cuCtxGetCurrent", lib_name);
82+
if (libcu_ops.cuCtxGetCurrent == nullptr) {
83+
fprintf(stderr, "cuCtxGetCurrent symbol not found in %s\n", lib_name);
84+
return -1;
85+
}
7786
*(void **)&libcu_ops.cuDeviceGet =
7887
utils_get_symbol_addr(cuDlHandle.get(), "cuDeviceGet", lib_name);
7988
if (libcu_ops.cuDeviceGet == nullptr) {
@@ -116,10 +125,17 @@ int InitCUDAOps() {
116125
fprintf(stderr, "cuMemsetD32_v2 symbol not found in %s\n", lib_name);
117126
return -1;
118127
}
119-
*(void **)&libcu_ops.cuMemcpyDtoH =
120-
utils_get_symbol_addr(cuDlHandle.get(), "cuMemcpyDtoH_v2", lib_name);
121-
if (libcu_ops.cuMemcpyDtoH == nullptr) {
122-
fprintf(stderr, "cuMemcpyDtoH_v2 symbol not found in %s\n", lib_name);
128+
*(void **)&libcu_ops.cuMemcpy =
129+
utils_get_symbol_addr(cuDlHandle.get(), "cuMemcpy", lib_name);
130+
if (libcu_ops.cuMemcpy == nullptr) {
131+
fprintf(stderr, "cuMemcpy symbol not found in %s\n", lib_name);
132+
return -1;
133+
}
134+
*(void **)&libcu_ops.cuPointerGetAttribute = utils_get_symbol_addr(
135+
cuDlHandle.get(), "cuPointerGetAttribute", lib_name);
136+
if (libcu_ops.cuPointerGetAttribute == nullptr) {
137+
fprintf(stderr, "cuPointerGetAttribute symbol not found in %s\n",
138+
lib_name);
123139
return -1;
124140
}
125141
*(void **)&libcu_ops.cuPointerGetAttributes = utils_get_symbol_addr(
@@ -140,14 +156,16 @@ int InitCUDAOps() {
140156
libcu_ops.cuInit = cuInit;
141157
libcu_ops.cuCtxCreate = cuCtxCreate;
142158
libcu_ops.cuCtxDestroy = cuCtxDestroy;
159+
libcu_ops.cuCtxGetCurrent = cuCtxGetCurrent;
143160
libcu_ops.cuDeviceGet = cuDeviceGet;
144161
libcu_ops.cuMemAlloc = cuMemAlloc;
145162
libcu_ops.cuMemAllocHost = cuMemAllocHost;
146163
libcu_ops.cuMemAllocManaged = cuMemAllocManaged;
147164
libcu_ops.cuMemFree = cuMemFree;
148165
libcu_ops.cuMemFreeHost = cuMemFreeHost;
149166
libcu_ops.cuMemsetD32 = cuMemsetD32;
150-
libcu_ops.cuMemcpyDtoH = cuMemcpyDtoH;
167+
libcu_ops.cuMemcpy = cuMemcpy;
168+
libcu_ops.cuPointerGetAttribute = cuPointerGetAttribute;
151169
libcu_ops.cuPointerGetAttributes = cuPointerGetAttributes;
152170

153171
return 0;
@@ -193,9 +211,10 @@ int cuda_copy(CUcontext context, CUdevice device, void *dst_ptr, void *src_ptr,
193211
(void)device;
194212

195213
int ret = 0;
196-
CUresult res = libcu_ops.cuMemcpyDtoH(dst_ptr, (CUdeviceptr)src_ptr, size);
214+
CUresult res =
215+
libcu_ops.cuMemcpy((CUdeviceptr)dst_ptr, (CUdeviceptr)src_ptr, size);
197216
if (res != CUDA_SUCCESS) {
198-
fprintf(stderr, "cuMemcpyDtoH() failed!\n");
217+
fprintf(stderr, "cuMemcpy() failed!\n");
199218
return -1;
200219
}
201220

@@ -230,6 +249,29 @@ umf_usm_memory_type_t get_mem_type(CUcontext context, void *ptr) {
230249
return UMF_MEMORY_TYPE_UNKNOWN;
231250
}
232251

252+
CUcontext get_mem_context(void *ptr) {
253+
CUcontext context;
254+
CUresult res = libcu_ops.cuPointerGetAttribute(
255+
&context, CU_POINTER_ATTRIBUTE_CONTEXT, (CUdeviceptr)ptr);
256+
if (res != CUDA_SUCCESS) {
257+
fprintf(stderr, "cuPointerGetAttribute() failed!\n");
258+
return nullptr;
259+
}
260+
261+
return context;
262+
}
263+
264+
CUcontext get_current_context() {
265+
CUcontext context;
266+
CUresult res = libcu_ops.cuCtxGetCurrent(&context);
267+
if (res != CUDA_SUCCESS) {
268+
fprintf(stderr, "cuCtxGetCurrent() failed!\n");
269+
return nullptr;
270+
}
271+
272+
return context;
273+
}
274+
233275
UTIL_ONCE_FLAG cuda_init_flag;
234276
int InitResult;
235277
void init_cuda_once() {

test/providers/cuda_helpers.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@ int cuda_copy(CUcontext context, CUdevice device, void *dst_ptr, void *src_ptr,
2626

2727
umf_usm_memory_type_t get_mem_type(CUcontext context, void *ptr);
2828

29+
CUcontext get_mem_context(void *ptr);
30+
31+
CUcontext get_current_context();
32+
2933
cuda_memory_provider_params_t
3034
create_cuda_prov_params(umf_usm_memory_type_t memory_type);
3135

test/providers/provider_cuda.cpp

Lines changed: 32 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,8 @@ using namespace umf_test;
2121

2222
class CUDAMemoryAccessor : public MemoryAccessor {
2323
public:
24-
void init(CUcontext hContext, CUdevice hDevice) {
25-
hDevice_ = hDevice;
26-
hContext_ = hContext;
27-
}
24+
CUDAMemoryAccessor(CUcontext hContext, CUdevice hDevice)
25+
: hDevice_(hDevice), hContext_(hContext) {}
2826

2927
void fill(void *ptr, size_t size, const void *pattern,
3028
size_t pattern_size) {
@@ -53,7 +51,7 @@ class CUDAMemoryAccessor : public MemoryAccessor {
5351
};
5452

5553
using CUDAProviderTestParams =
56-
std::tuple<umf_usm_memory_type_t, MemoryAccessor *>;
54+
std::tuple<cuda_memory_provider_params_t, MemoryAccessor *>;
5755

5856
struct umfCUDAProviderTest
5957
: umf_test::test,
@@ -62,23 +60,12 @@ struct umfCUDAProviderTest
6260
void SetUp() override {
6361
test::SetUp();
6462

65-
auto [memory_type, accessor] = this->GetParam();
66-
params = create_cuda_prov_params(memory_type);
63+
auto [cuda_params, accessor] = this->GetParam();
64+
params = cuda_params;
6765
memAccessor = accessor;
68-
if (memory_type == UMF_MEMORY_TYPE_DEVICE) {
69-
((CUDAMemoryAccessor *)memAccessor)
70-
->init((CUcontext)params.cuda_context_handle,
71-
params.cuda_device_handle);
72-
}
7366
}
7467

75-
void TearDown() override {
76-
if (params.cuda_context_handle) {
77-
int ret = destroy_context((CUcontext)params.cuda_context_handle);
78-
ASSERT_EQ(ret, 0);
79-
}
80-
test::TearDown();
81-
}
68+
void TearDown() override { test::TearDown(); }
8269

8370
cuda_memory_provider_params_t params;
8471
MemoryAccessor *memAccessor = nullptr;
@@ -87,6 +74,7 @@ struct umfCUDAProviderTest
8774
TEST_P(umfCUDAProviderTest, basic) {
8875
const size_t size = 1024 * 8;
8976
const uint32_t pattern = 0xAB;
77+
CUcontext expected_current_context = get_current_context();
9078

9179
// create CUDA provider
9280
umf_memory_provider_handle_t provider = nullptr;
@@ -113,6 +101,12 @@ TEST_P(umfCUDAProviderTest, basic) {
113101
// use the allocated memory - fill it with a 0xAB pattern
114102
memAccessor->fill(ptr, size, &pattern, sizeof(pattern));
115103

104+
CUcontext actual_mem_context = get_mem_context(ptr);
105+
ASSERT_EQ(actual_mem_context, (CUcontext)params.cuda_context_handle);
106+
107+
CUcontext actual_current_context = get_current_context();
108+
ASSERT_EQ(actual_current_context, expected_current_context);
109+
116110
umf_usm_memory_type_t memoryTypeActual =
117111
get_mem_type((CUcontext)params.cuda_context_handle, ptr);
118112
ASSERT_EQ(memoryTypeActual, params.memory_type);
@@ -132,6 +126,7 @@ TEST_P(umfCUDAProviderTest, basic) {
132126
}
133127

134128
TEST_P(umfCUDAProviderTest, allocInvalidSize) {
129+
CUcontext expected_current_context = get_current_context();
135130
// create CUDA provider
136131
umf_memory_provider_handle_t provider = nullptr;
137132
umf_result_t umf_result =
@@ -151,32 +146,32 @@ TEST_P(umfCUDAProviderTest, allocInvalidSize) {
151146
ASSERT_EQ(umf_result, UMF_RESULT_ERROR_INVALID_ARGUMENT);
152147
}
153148

154-
// destroy context and try to alloc some memory
155-
destroy_context((CUcontext)params.cuda_context_handle);
156-
params.cuda_context_handle = 0;
157-
umf_result = umfMemoryProviderAlloc(provider, 128, 0, &ptr);
158-
ASSERT_EQ(umf_result, UMF_RESULT_ERROR_MEMORY_PROVIDER_SPECIFIC);
159-
160-
const char *message;
161-
int32_t error;
162-
umfMemoryProviderGetLastNativeError(provider, &message, &error);
163-
ASSERT_EQ(error, CUDA_ERROR_INVALID_CONTEXT);
164-
const char *expected_message =
165-
"CUDA_ERROR_INVALID_CONTEXT - invalid device context";
166-
ASSERT_EQ(strncmp(message, expected_message, strlen(expected_message)), 0);
149+
CUcontext actual_current_context = get_current_context();
150+
ASSERT_EQ(actual_current_context, expected_current_context);
151+
152+
umfMemoryProviderDestroy(provider);
167153
}
168154

169155
// TODO add tests that mixes CUDA Memory Provider and Disjoint Pool
170156

171-
CUDAMemoryAccessor cuAccessor;
157+
cuda_memory_provider_params_t cuParams_device_memory =
158+
create_cuda_prov_params(UMF_MEMORY_TYPE_DEVICE);
159+
cuda_memory_provider_params_t cuParams_shared_memory =
160+
create_cuda_prov_params(UMF_MEMORY_TYPE_SHARED);
161+
cuda_memory_provider_params_t cuParams_host_memory =
162+
create_cuda_prov_params(UMF_MEMORY_TYPE_HOST);
163+
164+
CUDAMemoryAccessor
165+
cuAccessor((CUcontext)cuParams_device_memory.cuda_context_handle,
166+
(CUdevice)cuParams_device_memory.cuda_device_handle);
172167
HostMemoryAccessor hostAccessor;
173168

174169
INSTANTIATE_TEST_SUITE_P(
175170
umfCUDAProviderTestSuite, umfCUDAProviderTest,
176171
::testing::Values(
177-
CUDAProviderTestParams{UMF_MEMORY_TYPE_DEVICE, &cuAccessor},
178-
CUDAProviderTestParams{UMF_MEMORY_TYPE_SHARED, &hostAccessor},
179-
CUDAProviderTestParams{UMF_MEMORY_TYPE_HOST, &hostAccessor}));
172+
CUDAProviderTestParams{cuParams_device_memory, &cuAccessor},
173+
CUDAProviderTestParams{cuParams_shared_memory, &hostAccessor},
174+
CUDAProviderTestParams{cuParams_host_memory, &hostAccessor}));
180175

181176
// TODO: add IPC API
182177
GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(umfIpcTest);
@@ -185,5 +180,5 @@ INSTANTIATE_TEST_SUITE_P(umfCUDAProviderTestSuite, umfIpcTest,
185180
::testing::Values(ipcTestParams{
186181
umfProxyPoolOps(), nullptr,
187182
umfCUDAMemoryProviderOps(),
188-
&cuParams_device_memory, &l0Accessor}));
183+
&cuParams_device_memory, &cuAccessor}));
189184
*/

0 commit comments

Comments
 (0)