Skip to content

Commit 359333b

Browse files
committed
Fix CUDA provider to use proper context
1 parent 5bf1b5e commit 359333b

File tree

1 file changed

+46
-2
lines changed

1 file changed

+46
-2
lines changed

src/provider/provider_cuda.c

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ typedef struct cu_ops_t {
5151

5252
CUresult (*cuGetErrorName)(CUresult error, const char **pStr);
5353
CUresult (*cuGetErrorString)(CUresult error, const char **pStr);
54+
CUresult (*cuCtxGetCurrent)(CUcontext *pctx);
55+
CUresult (*cuCtxSetCurrent)(CUcontext ctx);
5456
} cu_ops_t;
5557

5658
static cu_ops_t g_cu_ops;
@@ -117,11 +119,16 @@ static void init_cu_global_state(void) {
117119
utils_get_symbol_addr(0, "cuGetErrorName", lib_name);
118120
*(void **)&g_cu_ops.cuGetErrorString =
119121
utils_get_symbol_addr(0, "cuGetErrorString", lib_name);
122+
*(void **)&g_cu_ops.cuCtxGetCurrent =
123+
utils_get_symbol_addr(0, "cuCtxGetCurrent", lib_name);
124+
*(void **)&g_cu_ops.cuCtxSetCurrent =
125+
utils_get_symbol_addr(0, "cuCtxSetCurrent", lib_name);
120126

121127
if (!g_cu_ops.cuMemGetAllocationGranularity || !g_cu_ops.cuMemAlloc ||
122128
!g_cu_ops.cuMemAllocHost || !g_cu_ops.cuMemAllocManaged ||
123129
!g_cu_ops.cuMemFree || !g_cu_ops.cuMemFreeHost ||
124-
!g_cu_ops.cuGetErrorName || !g_cu_ops.cuGetErrorString) {
130+
!g_cu_ops.cuGetErrorName || !g_cu_ops.cuGetErrorString ||
131+
!g_cu_ops.cuCtxGetCurrent || !g_cu_ops.cuCtxSetCurrent) {
125132
LOG_ERR("Required CUDA symbols not found.");
126133
Init_cu_global_state_failed = true;
127134
}
@@ -190,6 +197,31 @@ static void cu_memory_provider_finalize(void *provider) {
190197
umf_ba_global_free(provider);
191198
}
192199

200+
/*
201+
* This function is used by the CUDA provider to make sure that
202+
* the required context is set. If the current context is
203+
* not the required one, it will be saved in restore_ctx.
204+
*/
205+
static inline umf_result_t set_context(CUcontext required_ctx,
206+
CUcontext *restore_ctx) {
207+
CUcontext current_ctx = NULL;
208+
CUresult cu_result = g_cu_ops.cuCtxGetCurrent(&current_ctx);
209+
if (cu_result != CUDA_SUCCESS) {
210+
LOG_ERR("cuCtxGetCurrent() failed.");
211+
return cu2umf_result(cu_result);
212+
}
213+
*restore_ctx = current_ctx;
214+
if (current_ctx != required_ctx) {
215+
cu_result = g_cu_ops.cuCtxSetCurrent(required_ctx);
216+
if (cu_result != CUDA_SUCCESS) {
217+
LOG_ERR("cuCtxSetCurrent() failed.");
218+
return cu2umf_result(cu_result);
219+
}
220+
}
221+
222+
return UMF_RESULT_SUCCESS;
223+
}
224+
193225
static umf_result_t cu_memory_provider_alloc(void *provider, size_t size,
194226
size_t alignment,
195227
void **resultPtr) {
@@ -205,6 +237,13 @@ static umf_result_t cu_memory_provider_alloc(void *provider, size_t size,
205237
return UMF_RESULT_ERROR_NOT_SUPPORTED;
206238
}
207239

240+
// Remember current context and set the one from the provider
241+
CUcontext restore_ctx = NULL;
242+
umf_result_t umf_result = set_context(cu_provider->context, &restore_ctx);
243+
if (umf_result != UMF_RESULT_SUCCESS) {
244+
return umf_result;
245+
}
246+
208247
CUresult cu_result = CUDA_SUCCESS;
209248
switch (cu_provider->memory_type) {
210249
case UMF_MEMORY_TYPE_HOST: {
@@ -224,16 +263,21 @@ static umf_result_t cu_memory_provider_alloc(void *provider, size_t size,
224263
// this shouldn't happen as we check the memory_type settings during
225264
// the initialization
226265
LOG_ERR("unsupported USM memory type");
227-
return UMF_RESULT_ERROR_UNKNOWN;
266+
assert(false);
228267
}
229268

230269
// check the alignment
231270
if (alignment > 0 && ((uintptr_t)(*resultPtr) % alignment) != 0) {
232271
cu_memory_provider_free(provider, *resultPtr, size);
233272
LOG_ERR("unsupported alignment size");
273+
set_context(restore_ctx, &restore_ctx);
234274
return UMF_RESULT_ERROR_INVALID_ALIGNMENT;
235275
}
236276

277+
umf_result = set_context(restore_ctx, &restore_ctx);
278+
if (umf_result != UMF_RESULT_SUCCESS) {
279+
return umf_result;
280+
}
237281
return cu2umf_result(cu_result);
238282
}
239283

0 commit comments

Comments
 (0)