Skip to content

Commit e3d88c2

Browse files
authored
Merge pull request #2293 from yingcong-wu/yc-PR/241107-misc-minor-fix
[DeviceAsan] Serval bug fixes
2 parents 3d202c0 + cc4395b commit e3d88c2

File tree

4 files changed

+29
-15
lines changed

4 files changed

+29
-15
lines changed

source/loader/layers/sanitizer/asan/asan_ddi.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@ ur_result_t setupContext(ur_context_handle_t Context, uint32_t numDevices,
5555
bool isInstrumentedKernel(ur_kernel_handle_t hKernel) {
5656
auto hProgram = GetProgram(hKernel);
5757
auto PI = getAsanInterceptor()->getProgramInfo(hProgram);
58+
if (PI == nullptr) {
59+
return false;
60+
}
5861
return PI->isKernelInstrumented(hKernel);
5962
}
6063

@@ -290,8 +293,9 @@ __urdlllocal ur_result_t UR_APICALL urProgramRetain(
290293
UR_CALL(pfnRetain(hProgram));
291294

292295
auto ProgramInfo = getAsanInterceptor()->getProgramInfo(hProgram);
293-
UR_ASSERT(ProgramInfo != nullptr, UR_RESULT_ERROR_INVALID_VALUE);
294-
ProgramInfo->RefCount++;
296+
if (ProgramInfo != nullptr) {
297+
ProgramInfo->RefCount++;
298+
}
295299

296300
return UR_RESULT_SUCCESS;
297301
}
@@ -364,6 +368,7 @@ __urdlllocal ur_result_t UR_APICALL urProgramLink(
364368

365369
UR_CALL(pfnProgramLink(hContext, count, phPrograms, pOptions, phProgram));
366370

371+
UR_CALL(getAsanInterceptor()->insertProgram(*phProgram));
367372
UR_CALL(getAsanInterceptor()->registerProgram(*phProgram));
368373

369374
return UR_RESULT_SUCCESS;
@@ -395,6 +400,7 @@ ur_result_t UR_APICALL urProgramLinkExp(
395400
UR_CALL(pfnProgramLinkExp(hContext, numDevices, phDevices, count,
396401
phPrograms, pOptions, phProgram));
397402

403+
UR_CALL(getAsanInterceptor()->insertProgram(*phProgram));
398404
UR_CALL(getAsanInterceptor()->registerProgram(*phProgram));
399405

400406
return UR_RESULT_SUCCESS;
@@ -417,8 +423,7 @@ ur_result_t UR_APICALL urProgramRelease(
417423
UR_CALL(pfnProgramRelease(hProgram));
418424

419425
auto ProgramInfo = getAsanInterceptor()->getProgramInfo(hProgram);
420-
UR_ASSERT(ProgramInfo != nullptr, UR_RESULT_ERROR_INVALID_VALUE);
421-
if (--ProgramInfo->RefCount == 0) {
426+
if (ProgramInfo != nullptr && --ProgramInfo->RefCount == 0) {
422427
UR_CALL(getAsanInterceptor()->unregisterProgram(hProgram));
423428
UR_CALL(getAsanInterceptor()->eraseProgram(hProgram));
424429
}

source/loader/layers/sanitizer/asan/asan_interceptor.cpp

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -221,25 +221,28 @@ ur_result_t AsanInterceptor::releaseMemory(ur_context_handle_t Context,
221221
if (ReleaseList.size()) {
222222
std::scoped_lock<ur_shared_mutex> Guard(m_AllocationMapMutex);
223223
for (auto &It : ReleaseList) {
224+
auto ToFreeAllocInfo = It->second;
224225
getContext()->logger.info("Quarantine Free: {}",
225-
(void *)It->second->AllocBegin);
226+
(void *)ToFreeAllocInfo->AllocBegin);
226227

227-
ContextInfo->Stats.UpdateUSMRealFreed(AllocInfo->AllocSize,
228-
AllocInfo->getRedzoneSize());
228+
ContextInfo->Stats.UpdateUSMRealFreed(
229+
ToFreeAllocInfo->AllocSize, ToFreeAllocInfo->getRedzoneSize());
229230

230-
m_AllocationMap.erase(It);
231-
if (AllocInfo->Type == AllocType::HOST_USM) {
231+
if (ToFreeAllocInfo->Type == AllocType::HOST_USM) {
232232
for (auto &Device : ContextInfo->DeviceList) {
233233
UR_CALL(getDeviceInfo(Device)->Shadow->ReleaseShadow(
234-
AllocInfo));
234+
ToFreeAllocInfo));
235235
}
236236
} else {
237-
UR_CALL(getDeviceInfo(AllocInfo->Device)
238-
->Shadow->ReleaseShadow(AllocInfo));
237+
UR_CALL(getDeviceInfo(ToFreeAllocInfo->Device)
238+
->Shadow->ReleaseShadow(ToFreeAllocInfo));
239239
}
240240

241241
UR_CALL(getContext()->urDdiTable.USM.pfnFree(
242-
Context, (void *)(It->second->AllocBegin)));
242+
Context, (void *)(ToFreeAllocInfo->AllocBegin)));
243+
244+
// Erase it at last to avoid use-after-free.
245+
m_AllocationMap.erase(It);
243246
}
244247
}
245248
ContextInfo->Stats.UpdateUSMFreed(AllocInfo->AllocSize);
@@ -431,6 +434,7 @@ ur_result_t AsanInterceptor::registerProgram(ur_program_handle_t Program) {
431434

432435
ur_result_t AsanInterceptor::unregisterProgram(ur_program_handle_t Program) {
433436
auto ProgramInfo = getProgramInfo(Program);
437+
assert(ProgramInfo != nullptr && "unregistered program!");
434438

435439
for (auto AI : ProgramInfo->AllocInfoForGlobals) {
436440
UR_CALL(getDeviceInfo(AI->Device)->Shadow->ReleaseShadow(AI));
@@ -475,6 +479,7 @@ ur_result_t AsanInterceptor::registerSpirKernels(ur_program_handle_t Program) {
475479
}
476480

477481
auto PI = getProgramInfo(Program);
482+
assert(PI != nullptr && "unregistered program!");
478483
for (const auto &SKI : SKInfo) {
479484
if (SKI.Size == 0) {
480485
continue;
@@ -511,6 +516,7 @@ AsanInterceptor::registerDeviceGlobals(ur_program_handle_t Program) {
511516
auto Context = GetContext(Program);
512517
auto ContextInfo = getContextInfo(Context);
513518
auto ProgramInfo = getProgramInfo(Program);
519+
assert(ProgramInfo != nullptr && "unregistered program!");
514520

515521
for (auto Device : Devices) {
516522
ManagedQueue Queue(Context, Device);

source/loader/layers/sanitizer/asan/asan_interceptor.hpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -334,8 +334,10 @@ class AsanInterceptor {
334334

335335
std::shared_ptr<ProgramInfo> getProgramInfo(ur_program_handle_t Program) {
336336
std::shared_lock<ur_shared_mutex> Guard(m_ProgramMapMutex);
337-
assert(m_ProgramMap.find(Program) != m_ProgramMap.end());
338-
return m_ProgramMap[Program];
337+
if (m_ProgramMap.find(Program) != m_ProgramMap.end()) {
338+
return m_ProgramMap[Program];
339+
}
340+
return nullptr;
339341
}
340342

341343
std::shared_ptr<KernelInfo> getKernelInfo(ur_kernel_handle_t Kernel) {

source/loader/layers/sanitizer/asan/asan_shadow.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,7 @@ ur_result_t ShadowMemoryGPU::ReleaseShadow(std::shared_ptr<AllocInfo> AI) {
250250
getContext()->logger.debug("urVirtualMemUnmap: {} ~ {}",
251251
(void *)MappedPtr,
252252
(void *)(MappedPtr + PageSize - 1));
253+
VirtualMemMaps.erase(MappedPtr);
253254
}
254255
}
255256

0 commit comments

Comments
 (0)