Skip to content

Commit 2cc0206

Browse files
committed
Fix cuda_copy function.
1 parent b8088be commit 2cc0206

File tree

1 file changed

+15
-0
lines changed

1 file changed

+15
-0
lines changed

test/providers/cuda_helpers.cpp

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -33,6 +33,7 @@ struct libcu_ops {
3333
CUresult (*cuPointerGetAttributes)(unsigned int numAttributes,
3434
CUpointer_attribute *attributes,
3535
void **data, CUdeviceptr ptr);
36+
CUresult (*cuStreamSynchronize)(CUstream hStream);
3637
} libcu_ops;
3738

3839
#if USE_DLOPEN
@@ -145,6 +146,13 @@ int InitCUDAOps() {
145146
lib_name);
146147
return -1;
147148
}
149+
*(void **)&libcu_ops.cuStreamSynchronize = utils_get_symbol_addr(
150+
cuDlHandle.get(), "cuStreamSynchronize", lib_name);
151+
if (libcu_ops.cuStreamSynchronize == nullptr) {
152+
fprintf(stderr, "cuStreamSynchronize symbol not found in %s\n",
153+
lib_name);
154+
return -1;
155+
}
148156

149157
return 0;
150158
}
@@ -167,6 +175,7 @@ int InitCUDAOps() {
167175
libcu_ops.cuMemcpy = cuMemcpy;
168176
libcu_ops.cuPointerGetAttribute = cuPointerGetAttribute;
169177
libcu_ops.cuPointerGetAttributes = cuPointerGetAttributes;
178+
libcu_ops.cuStreamSynchronize = cuStreamSynchronize;
170179

171180
return 0;
172181
}
@@ -218,6 +227,12 @@ int cuda_copy(CUcontext context, CUdevice device, void *dst_ptr, void *src_ptr,
218227
return -1;
219228
}
220229

230+
res = libcu_ops.cuStreamSynchronize(0);
231+
if (res != CUDA_SUCCESS) {
232+
fprintf(stderr, "cuStreamSynchronize() failed!\n");
233+
return -1;
234+
}
235+
221236
return ret;
222237
}
223238

0 commit comments

Comments
 (0)