Skip to content

Commit a6511d5

Browse files
committed
Resolve cudaSetDevice context creation
Context should only be created on target device
1 parent 15935e5 commit a6511d5

File tree

3 files changed

+45
-19
lines changed

3 files changed

+45
-19
lines changed

cuda/_lib/ccudart/ccudart.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1059,7 +1059,7 @@ cdef cudaError_t _cudaGetDevice(int* device) nogil except ?cudaErrorCallRequires
10591059

10601060
cdef cudaError_t _cudaSetDevice(int device) nogil except ?cudaErrorCallRequiresNewerDriver:
10611061
cdef cudaError_t err
1062-
err = m_global.lazyInit()
1062+
err = m_global.lazyInitGlobal()
10631063
if err != cudaSuccess:
10641064
return err
10651065
if device < 0 or device >= m_global._numDevices:

cuda/_lib/ccudart/utils.pxd

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ cimport cuda._cuda.ccuda as ccuda
1313

1414
cdef class cudaPythonGlobal:
1515
cdef bint _cudaPythonInit
16+
cdef bint _cudaPythonGlobalInit
1617
cdef int _numDevices
1718
cdef ccuda.CUdevice* _driverDevice
1819
cdef ccuda.CUcontext* _driverContext
@@ -22,6 +23,7 @@ cdef class cudaPythonGlobal:
2223
cdef int CUDART_VERSION
2324

2425
cdef cudaError_t lazyInit(self) nogil
26+
cdef cudaError_t lazyInitGlobal(self) nogil
2527
cdef cudaError_t lazyInitDevice(self, int deviceOrdinal) nogil
2628

2729
cdef cudaPythonGlobal globalGetInstance()

cuda/_lib/ccudart/utils.pyx

Lines changed: 42 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ ctypedef cudaStreamCallbackData_st cudaStreamCallbackData
2929
cdef class cudaPythonGlobal:
3030
def __cinit__(self):
3131
self._cudaPythonInit = False
32+
self._cudaPythonGlobalInit = False
3233
self._numDevices = 0
3334
self._driverDevice = NULL
3435
self._driverContext = NULL
@@ -54,25 +55,10 @@ cdef class cudaPythonGlobal:
5455
if self._numDevices > 0:
5556
ccuda._cuCtxSetCurrent(self._driverContext[0])
5657
return cudaSuccess
57-
err = ccuda._cuInit(0)
58-
if err != ccuda.cudaError_enum.CUDA_SUCCESS:
59-
return <cudaError_t>err
60-
err = ccuda._cuDeviceGetCount(&self._numDevices)
61-
if err != ccuda.cudaError_enum.CUDA_SUCCESS:
62-
return cudaErrorInitializationError
6358

64-
self._driverDevice = <ccuda.CUdevice *>calloc(self._numDevices, sizeof(ccuda.CUdevice))
65-
if self._driverDevice == NULL:
66-
return cudaErrorMemoryAllocation
67-
self._driverContext = <ccuda.CUcontext *>calloc(self._numDevices, sizeof(ccuda.CUcontext))
68-
if self._driverContext == NULL:
69-
return cudaErrorMemoryAllocation
70-
self._deviceProperties = <cudaDeviceProp *>calloc(self._numDevices, sizeof(cudaDeviceProp))
71-
if self._deviceProperties == NULL:
72-
return cudaErrorMemoryAllocation
73-
self._deviceInit = <bool *>calloc(self._numDevices, sizeof(bool))
74-
if self._deviceInit == NULL:
75-
return cudaErrorMemoryAllocation
59+
err_rt = self.lazyInitGlobal()
60+
if err_rt != cudaSuccess:
61+
return err_rt
7662

7763
err_rt = self.lazyInitDevice(0)
7864
if err_rt != cudaSuccess:
@@ -83,6 +69,44 @@ cdef class cudaPythonGlobal:
8369
return cudaErrorInitializationError
8470
self._cudaPythonInit = True
8571

72+
cdef cudaError_t lazyInitGlobal(self) nogil:
    # One-time global setup that does not create any context: initializes the
    # driver via _cuInit, queries the device count, and allocates the
    # per-device bookkeeping arrays (_driverDevice, _driverContext,
    # _deviceProperties, _deviceInit), each sized by _numDevices.
    #
    # Returns:
    #   cudaSuccess                   - already initialized, or setup succeeded
    #   the _cuInit error (cast)      - driver initialization failed
    #   cudaErrorInitializationError  - device count query failed
    #   cudaErrorMemoryAllocation     - any bookkeeping array could not be allocated
    #
    # On allocation failure every array is released and the members are reset
    # to the state __cinit__ establishes (NULL pointers, zero devices), so the
    # object never holds dangling pointers and a later call can retry cleanly.
    cdef cudaError_t err = cudaSuccess
    if self._cudaPythonGlobalInit:
        return err

    err = <cudaError_t>ccuda._cuInit(0)
    if err != cudaSuccess:
        return err
    err = <cudaError_t>ccuda._cuDeviceGetCount(&self._numDevices)
    if err != cudaSuccess:
        return cudaErrorInitializationError

    # Attempt all four allocations, then validate them together; calloc
    # zero-fills, so every per-device slot starts out cleared.
    self._driverDevice = <ccuda.CUdevice *>calloc(self._numDevices, sizeof(ccuda.CUdevice))
    self._driverContext = <ccuda.CUcontext *>calloc(self._numDevices, sizeof(ccuda.CUcontext))
    self._deviceProperties = <cudaDeviceProp *>calloc(self._numDevices, sizeof(cudaDeviceProp))
    self._deviceInit = <bool *>calloc(self._numDevices, sizeof(bool))
    if (self._driverDevice == NULL or self._driverContext == NULL or
            self._deviceProperties == NULL or self._deviceInit == NULL):
        err = cudaErrorMemoryAllocation

    if err != cudaSuccess:
        # free(NULL) is a no-op, so no per-pointer guard is needed. Each
        # pointer is cleared right after release to prevent a double free or
        # use-after-free through the stale member.
        free(self._deviceInit)
        self._deviceInit = NULL
        free(self._deviceProperties)
        self._deviceProperties = NULL
        free(self._driverContext)
        self._driverContext = NULL
        free(self._driverDevice)
        self._driverDevice = NULL
        # Keep the device count consistent with the (now absent) arrays.
        self._numDevices = 0
    else:
        self._cudaPythonGlobalInit = True
    return err
109+
86110
cdef cudaError_t lazyInitDevice(self, int deviceOrdinal) nogil:
87111
if self._deviceInit[deviceOrdinal]:
88112
return cudaSuccess

0 commit comments

Comments
 (0)