Skip to content

Commit d91c774

Browse files
authored
Merge pull request #1049 from bratpiorka/rrudnick_cuda_free
add set/restore context in CUDA provider free()
2 parents d3daaf6 + 4acb4e9 commit d91c774

File tree

4 files changed

+88
-5
lines changed

4 files changed

+88
-5
lines changed

src/provider/provider_cuda.c

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,14 @@ static umf_result_t cu_memory_provider_free(void *provider, void *ptr,
433433

434434
cu_memory_provider_t *cu_provider = (cu_memory_provider_t *)provider;
435435

436+
// Remember current context and set the one from the provider
437+
CUcontext restore_ctx = NULL;
438+
umf_result_t umf_result = set_context(cu_provider->context, &restore_ctx);
439+
if (umf_result != UMF_RESULT_SUCCESS) {
440+
LOG_ERR("Failed to set CUDA context, ret = %d", umf_result);
441+
return umf_result;
442+
}
443+
436444
CUresult cu_result = CUDA_SUCCESS;
437445
switch (cu_provider->memory_type) {
438446
case UMF_MEMORY_TYPE_HOST: {
@@ -451,6 +459,11 @@ static umf_result_t cu_memory_provider_free(void *provider, void *ptr,
451459
return UMF_RESULT_ERROR_UNKNOWN;
452460
}
453461

462+
umf_result = set_context(restore_ctx, &restore_ctx);
463+
if (umf_result != UMF_RESULT_SUCCESS) {
464+
LOG_ERR("Failed to restore CUDA context, ret = %d", umf_result);
465+
}
466+
454467
return cu2umf_result(cu_result);
455468
}
456469

test/providers/cuda_helpers.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (C) 2024 Intel Corporation
2+
* Copyright (C) 2024-2025 Intel Corporation
33
*
44
* Under the Apache License v2.0 with LLVM Exceptions. See LICENSE.TXT.
55
* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
@@ -251,15 +251,18 @@ int InitCUDAOps() {
251251
}
252252
#endif // USE_DLOPEN
253253

254-
static CUresult set_context(CUcontext required_ctx, CUcontext *restore_ctx) {
254+
CUresult set_context(CUcontext required_ctx, CUcontext *restore_ctx) {
255255
CUcontext current_ctx = NULL;
256256
CUresult cu_result = libcu_ops.cuCtxGetCurrent(&current_ctx);
257257
if (cu_result != CUDA_SUCCESS) {
258258
fprintf(stderr, "cuCtxGetCurrent() failed.\n");
259259
return cu_result;
260260
}
261261

262-
*restore_ctx = current_ctx;
262+
if (restore_ctx != NULL) {
263+
*restore_ctx = current_ctx;
264+
}
265+
263266
if (current_ctx != required_ctx) {
264267
cu_result = libcu_ops.cuCtxSetCurrent(required_ctx);
265268
if (cu_result != CUDA_SUCCESS) {

test/providers/cuda_helpers.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (C) 2024 Intel Corporation
2+
* Copyright (C) 2024-2025 Intel Corporation
33
*
44
* Under the Apache License v2.0 with LLVM Exceptions. See LICENSE.TXT.
55
* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
@@ -30,6 +30,8 @@ int get_cuda_device(CUdevice *device);
3030

3131
int create_context(CUdevice device, CUcontext *context);
3232

33+
CUresult set_context(CUcontext required_ctx, CUcontext *restore_ctx);
34+
3335
int destroy_context(CUcontext context);
3436

3537
int cuda_fill(CUcontext context, CUdevice device, void *ptr, size_t size,

test/providers/provider_cuda.cpp

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (C) 2024 Intel Corporation
1+
// Copyright (C) 2024-2025 Intel Corporation
22
// Under the Apache License v2.0 with LLVM Exceptions. See LICENSE.TXT.
33
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
44

@@ -315,6 +315,71 @@ TEST_P(umfCUDAProviderTest, cudaProviderNullParams) {
315315
EXPECT_EQ(res, UMF_RESULT_ERROR_INVALID_ARGUMENT);
316316
}
317317

318+
TEST_P(umfCUDAProviderTest, multiContext) {
319+
CUdevice device;
320+
int ret = get_cuda_device(&device);
321+
ASSERT_EQ(ret, 0);
322+
323+
// create two CUDA contexts and two providers
324+
CUcontext ctx1, ctx2;
325+
ret = create_context(device, &ctx1);
326+
ASSERT_EQ(ret, 0);
327+
ret = create_context(device, &ctx2);
328+
ASSERT_EQ(ret, 0);
329+
330+
cuda_params_unique_handle_t params1 =
331+
create_cuda_prov_params(ctx1, device, UMF_MEMORY_TYPE_HOST);
332+
ASSERT_NE(params1, nullptr);
333+
umf_memory_provider_handle_t provider1;
334+
umf_result_t umf_result = umfMemoryProviderCreate(
335+
umfCUDAMemoryProviderOps(), params1.get(), &provider1);
336+
ASSERT_EQ(umf_result, UMF_RESULT_SUCCESS);
337+
ASSERT_NE(provider1, nullptr);
338+
339+
cuda_params_unique_handle_t params2 =
340+
create_cuda_prov_params(ctx2, device, UMF_MEMORY_TYPE_HOST);
341+
ASSERT_NE(params2, nullptr);
342+
umf_memory_provider_handle_t provider2;
343+
umf_result = umfMemoryProviderCreate(umfCUDAMemoryProviderOps(),
344+
params2.get(), &provider2);
345+
ASSERT_EQ(umf_result, UMF_RESULT_SUCCESS);
346+
ASSERT_NE(provider2, nullptr);
347+
348+
// use the providers
349+
// allocate from 1, then from 2, then free 1, then free 2
350+
void *ptr1, *ptr2;
351+
const int size = 128;
352+
// NOTE: we use ctx1 here
353+
umf_result = umfMemoryProviderAlloc(provider1, size, 0, &ptr1);
354+
ASSERT_EQ(umf_result, UMF_RESULT_SUCCESS);
355+
ASSERT_NE(ptr1, nullptr);
356+
357+
// NOTE: we use ctx2 here
358+
umf_result = umfMemoryProviderAlloc(provider2, size, 0, &ptr2);
359+
ASSERT_EQ(umf_result, UMF_RESULT_SUCCESS);
360+
ASSERT_NE(ptr2, nullptr);
361+
362+
// even if we change the context, we should be able to free the memory
363+
ret = set_context(ctx2, NULL);
364+
ASSERT_EQ(ret, 0);
365+
// free memory from ctx1
366+
umf_result = umfMemoryProviderFree(provider1, ptr1, size);
367+
ASSERT_EQ(umf_result, UMF_RESULT_SUCCESS);
368+
369+
ret = set_context(ctx1, NULL);
370+
ASSERT_EQ(ret, 0);
371+
umf_result = umfMemoryProviderFree(provider2, ptr2, size);
372+
ASSERT_EQ(umf_result, UMF_RESULT_SUCCESS);
373+
374+
// cleanup
375+
umfMemoryProviderDestroy(provider2);
376+
umfMemoryProviderDestroy(provider1);
377+
ret = destroy_context(ctx1);
378+
ASSERT_EQ(ret, 0);
379+
ret = destroy_context(ctx2);
380+
ASSERT_EQ(ret, 0);
381+
}
382+
318383
// TODO add tests that mixes CUDA Memory Provider and Disjoint Pool
319384

320385
CUDATestHelper cudaTestHelper;

0 commit comments

Comments
 (0)