@@ -51,6 +51,8 @@ typedef struct cu_ops_t {
51
51
52
52
CUresult (* cuGetErrorName )(CUresult error , const char * * pStr );
53
53
CUresult (* cuGetErrorString )(CUresult error , const char * * pStr );
54
+ CUresult (* cuCtxGetCurrent )(CUcontext * pctx );
55
+ CUresult (* cuCtxSetCurrent )(CUcontext ctx );
54
56
} cu_ops_t ;
55
57
56
58
static cu_ops_t g_cu_ops ;
@@ -117,11 +119,16 @@ static void init_cu_global_state(void) {
117
119
utils_get_symbol_addr (0 , "cuGetErrorName" , lib_name );
118
120
* (void * * )& g_cu_ops .cuGetErrorString =
119
121
utils_get_symbol_addr (0 , "cuGetErrorString" , lib_name );
122
+ * (void * * )& g_cu_ops .cuCtxGetCurrent =
123
+ utils_get_symbol_addr (0 , "cuCtxGetCurrent" , lib_name );
124
+ * (void * * )& g_cu_ops .cuCtxSetCurrent =
125
+ utils_get_symbol_addr (0 , "cuCtxSetCurrent" , lib_name );
120
126
121
127
if (!g_cu_ops .cuMemGetAllocationGranularity || !g_cu_ops .cuMemAlloc ||
122
128
!g_cu_ops .cuMemAllocHost || !g_cu_ops .cuMemAllocManaged ||
123
129
!g_cu_ops .cuMemFree || !g_cu_ops .cuMemFreeHost ||
124
- !g_cu_ops .cuGetErrorName || !g_cu_ops .cuGetErrorString ) {
130
+ !g_cu_ops .cuGetErrorName || !g_cu_ops .cuGetErrorString ||
131
+ !g_cu_ops .cuCtxGetCurrent || !g_cu_ops .cuCtxSetCurrent ) {
125
132
LOG_ERR ("Required CUDA symbols not found." );
126
133
Init_cu_global_state_failed = true;
127
134
}
@@ -190,6 +197,31 @@ static void cu_memory_provider_finalize(void *provider) {
190
197
umf_ba_global_free (provider );
191
198
}
192
199
200
+ /*
201
+ * This function is used by the CUDA provider to make sure that
202
+ * the required context is set. If the current context is
203
+ * not the required one, it will be saved in restore_ctx.
204
+ */
205
+ static inline umf_result_t set_context (CUcontext required_ctx ,
206
+ CUcontext * restore_ctx ) {
207
+ CUcontext current_ctx = NULL ;
208
+ CUresult cu_result = g_cu_ops .cuCtxGetCurrent (& current_ctx );
209
+ if (cu_result != CUDA_SUCCESS ) {
210
+ LOG_ERR ("cuCtxGetCurrent() failed." );
211
+ return cu2umf_result (cu_result );
212
+ }
213
+ * restore_ctx = current_ctx ;
214
+ if (current_ctx != required_ctx ) {
215
+ cu_result = g_cu_ops .cuCtxSetCurrent (required_ctx );
216
+ if (cu_result != CUDA_SUCCESS ) {
217
+ LOG_ERR ("cuCtxSetCurrent() failed." );
218
+ return cu2umf_result (cu_result );
219
+ }
220
+ }
221
+
222
+ return UMF_RESULT_SUCCESS ;
223
+ }
224
+
193
225
static umf_result_t cu_memory_provider_alloc (void * provider , size_t size ,
194
226
size_t alignment ,
195
227
void * * resultPtr ) {
@@ -205,6 +237,13 @@ static umf_result_t cu_memory_provider_alloc(void *provider, size_t size,
205
237
return UMF_RESULT_ERROR_NOT_SUPPORTED ;
206
238
}
207
239
240
+ // Remember current context and set the one from the provider
241
+ CUcontext restore_ctx = NULL ;
242
+ umf_result_t umf_result = set_context (cu_provider -> context , & restore_ctx );
243
+ if (umf_result != UMF_RESULT_SUCCESS ) {
244
+ return umf_result ;
245
+ }
246
+
208
247
CUresult cu_result = CUDA_SUCCESS ;
209
248
switch (cu_provider -> memory_type ) {
210
249
case UMF_MEMORY_TYPE_HOST : {
@@ -224,16 +263,21 @@ static umf_result_t cu_memory_provider_alloc(void *provider, size_t size,
224
263
// this shouldn't happen as we check the memory_type settings during
225
264
// the initialization
226
265
LOG_ERR ("unsupported USM memory type" );
227
- return UMF_RESULT_ERROR_UNKNOWN ;
266
+ assert (false) ;
228
267
}
229
268
230
269
// check the alignment
231
270
if (alignment > 0 && ((uintptr_t )(* resultPtr ) % alignment ) != 0 ) {
232
271
cu_memory_provider_free (provider , * resultPtr , size );
233
272
LOG_ERR ("unsupported alignment size" );
273
+ set_context (restore_ctx , & restore_ctx );
234
274
return UMF_RESULT_ERROR_INVALID_ALIGNMENT ;
235
275
}
236
276
277
+ umf_result = set_context (restore_ctx , & restore_ctx );
278
+ if (umf_result != UMF_RESULT_SUCCESS ) {
279
+ return umf_result ;
280
+ }
237
281
return cu2umf_result (cu_result );
238
282
}
239
283
0 commit comments