Skip to content

Commit e86f38d

Browse files
aarongreigcallumfare
authored andcommitted
[SYCL][CUDA] Add a few extra checks to the cuda UR program implementation.
1 parent abe6aa0 commit e86f38d

File tree

1 file changed

+14
-3
lines changed
  • sycl/plugins/unified_runtime/ur/adapters/cuda

1 file changed

+14
-3
lines changed

sycl/plugins/unified_runtime/ur/adapters/cuda/program.cpp

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -371,7 +371,15 @@ urProgramRelease(ur_program_handle_t program) {
371371
try {
372372
ScopedContext active(program->get_context());
373373
auto cuModule = program->get();
374-
result = UR_CHECK_ERROR(cuModuleUnload(cuModule));
374+
// "0" is a valid handle for a cuModule, so the best way to check if we
375+
// actually loaded a module and need to unload it is to look at the build
376+
// status.
377+
if (program->buildStatus_ == UR_PROGRAM_BUILD_STATUS_SUCCESS) {
378+
result = UR_CHECK_ERROR(cuModuleUnload(cuModule));
379+
} else if(program->buildStatus_ == UR_PROGRAM_BUILD_STATUS_NONE) {
380+
// Nothing to free.
381+
result = UR_RESULT_SUCCESS;
382+
}
375383
} catch (...) {
376384
result = UR_RESULT_ERROR_OUT_OF_RESOURCES;
377385
}
@@ -391,6 +399,7 @@ urProgramRelease(ur_program_handle_t program) {
391399
UR_APIEXPORT ur_result_t UR_APICALL urProgramGetNativeHandle(
392400
ur_program_handle_t program, ur_native_handle_t *nativeHandle) {
393401
UR_ASSERT(program, UR_RESULT_ERROR_INVALID_NULL_HANDLE);
402+
UR_ASSERT(nativeHandle, UR_RESULT_ERROR_INVALID_NULL_POINTER);
394403
*nativeHandle = reinterpret_cast<ur_native_handle_t>(program->get());
395404
return UR_RESULT_SUCCESS;
396405
}
@@ -417,8 +426,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramCreateWithBinary(
417426
std::unique_ptr<ur_program_handle_t_> retProgram{
418427
new ur_program_handle_t_{hContext}};
419428

420-
retError =
421-
retProgram->set_metadata(pProperties->pMetadatas, pProperties->count);
429+
if (pProperties && pProperties->pMetadatas) {
430+
retError =
431+
retProgram->set_metadata(pProperties->pMetadatas, pProperties->count);
432+
}
422433
UR_ASSERT(retError == UR_RESULT_SUCCESS, retError);
423434

424435
auto pBinary_string = reinterpret_cast<const char *>(pBinary);

0 commit comments

Comments
 (0)