Skip to content

Commit 4713aeb

Browse files
zjin-lcf (Jin Z) and steffenlarsen
authored
[SYCL][CUDA] Set the device primary context for the cuMemGetInfo call (#7906)
This PR attempts to fix the `Cuda API error detected: cuMemGetInfo_v2 returned (0xc9)` error reported in #5713. Thank you for your review.
---------
Co-authored-by: Jin Z <[email protected]>
Co-authored-by: Steffen Larsen <[email protected]>
1 parent 18f2cda commit 4713aeb

File tree

1 file changed

+16
-0
lines changed

1 file changed

+16
-0
lines changed

sycl/plugins/cuda/pi_cuda.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1931,11 +1931,27 @@ pi_result cuda_piDeviceGetInfo(pi_device device, pi_device_info param_name,
19311931
}
19321932

19331933
case PI_EXT_INTEL_DEVICE_INFO_FREE_MEMORY: {
1934+
// Check whether the currently set context belongs to the queried device.
1935+
// CUDA_ERROR_INVALID_CONTEXT signifies the absence of an active context.
1936+
CUdevice current_ctx_device;
1937+
CUresult current_ctx_device_ret = cuCtxGetDevice(&current_ctx_device);
1938+
if (current_ctx_device_ret != CUDA_ERROR_INVALID_CONTEXT)
1939+
PI_CHECK_ERROR(current_ctx_device_ret);
1940+
bool need_primary_ctx = current_ctx_device_ret == CUDA_ERROR_INVALID_CONTEXT ||
1941+
current_ctx_device != device->get();
1942+
if (need_primary_ctx) {
1943+
// Use the primary context for the device if no context with the device is set.
1944+
CUcontext primary_context;
1945+
PI_CHECK_ERROR(cuDevicePrimaryCtxRetain(&primary_context, device->get()));
1946+
PI_CHECK_ERROR(cuCtxSetCurrent(primary_context));
1947+
}
19341948
size_t FreeMemory = 0;
19351949
size_t TotalMemory = 0;
19361950
sycl::detail::pi::assertion(cuMemGetInfo(&FreeMemory, &TotalMemory) ==
19371951
CUDA_SUCCESS,
19381952
"failed cuMemGetInfo() API.");
1953+
if (need_primary_ctx)
1954+
PI_CHECK_ERROR(cuDevicePrimaryCtxRelease(device->get()));
19391955
return getInfo(param_value_size, param_value, param_value_size_ret,
19401956
FreeMemory);
19411957
}

0 commit comments

Comments
 (0)