Skip to content

Commit 4022e19

Browse files
committed
add getOrCreateKernelInfo
1 parent 667d4d1 commit 4022e19

File tree

3 files changed

+38
-63
lines changed

3 files changed

+38
-63
lines changed

source/loader/layers/sanitizer/asan/asan_ddi.cpp

Lines changed: 17 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1335,28 +1335,6 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemUnmap(
13351335
return UR_RESULT_SUCCESS;
13361336
}
13371337

1338-
///////////////////////////////////////////////////////////////////////////////
1339-
/// @brief Intercept function for urKernelCreate
1340-
__urdlllocal ur_result_t UR_APICALL urKernelCreate(
1341-
ur_program_handle_t hProgram, ///< [in] handle of the program instance
1342-
const char *pKernelName, ///< [in] pointer to null-terminated string.
1343-
ur_kernel_handle_t
1344-
*phKernel ///< [out] pointer to handle of kernel object created.
1345-
) {
1346-
auto pfnCreate = getContext()->urDdiTable.Kernel.pfnCreate;
1347-
1348-
if (nullptr == pfnCreate) {
1349-
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
1350-
}
1351-
1352-
getContext()->logger.debug("==== urKernelCreate");
1353-
1354-
UR_CALL(pfnCreate(hProgram, pKernelName, phKernel));
1355-
UR_CALL(getAsanInterceptor()->insertKernel(*phKernel));
1356-
1357-
return UR_RESULT_SUCCESS;
1358-
}
1359-
13601338
///////////////////////////////////////////////////////////////////////////////
13611339
/// @brief Intercept function for urKernelRetain
13621340
__urdlllocal ur_result_t UR_APICALL urKernelRetain(
@@ -1372,8 +1350,8 @@ __urdlllocal ur_result_t UR_APICALL urKernelRetain(
13721350

13731351
UR_CALL(pfnRetain(hKernel));
13741352

1375-
auto KernelInfo = getAsanInterceptor()->getKernelInfo(hKernel);
1376-
KernelInfo->RefCount++;
1353+
auto &KernelInfo = getAsanInterceptor()->getOrCreateKernelInfo(hKernel);
1354+
KernelInfo.RefCount++;
13771355

13781356
return UR_RESULT_SUCCESS;
13791357
}
@@ -1392,9 +1370,9 @@ __urdlllocal ur_result_t urKernelRelease(
13921370
getContext()->logger.debug("==== urKernelRelease");
13931371
UR_CALL(pfnRelease(hKernel));
13941372

1395-
auto KernelInfo = getAsanInterceptor()->getKernelInfo(hKernel);
1396-
if (--KernelInfo->RefCount == 0) {
1397-
UR_CALL(getAsanInterceptor()->eraseKernel(hKernel));
1373+
auto &KernelInfo = getAsanInterceptor()->getOrCreateKernelInfo(hKernel);
1374+
if (--KernelInfo.RefCount == 0) {
1375+
UR_CALL(getAsanInterceptor()->eraseKernelInfo(hKernel));
13981376
}
13991377

14001378
return UR_RESULT_SUCCESS;
@@ -1423,9 +1401,9 @@ __urdlllocal ur_result_t UR_APICALL urKernelSetArgValue(
14231401
if (argSize == sizeof(ur_mem_handle_t) &&
14241402
(MemBuffer = getAsanInterceptor()->getMemBuffer(
14251403
*ur_cast<const ur_mem_handle_t *>(pArgValue)))) {
1426-
auto KernelInfo = getAsanInterceptor()->getKernelInfo(hKernel);
1427-
std::scoped_lock<ur_shared_mutex> Guard(KernelInfo->Mutex);
1428-
KernelInfo->BufferArgs[argIndex] = std::move(MemBuffer);
1404+
auto &KernelInfo = getAsanInterceptor()->getOrCreateKernelInfo(hKernel);
1405+
std::scoped_lock<ur_shared_mutex> Guard(KernelInfo.Mutex);
1406+
KernelInfo.BufferArgs[argIndex] = std::move(MemBuffer);
14291407
} else {
14301408
UR_CALL(
14311409
pfnSetArgValue(hKernel, argIndex, argSize, pProperties, pArgValue));
@@ -1453,9 +1431,9 @@ __urdlllocal ur_result_t UR_APICALL urKernelSetArgMemObj(
14531431

14541432
std::shared_ptr<MemBuffer> MemBuffer;
14551433
if ((MemBuffer = getAsanInterceptor()->getMemBuffer(hArgValue))) {
1456-
auto KernelInfo = getAsanInterceptor()->getKernelInfo(hKernel);
1457-
std::scoped_lock<ur_shared_mutex> Guard(KernelInfo->Mutex);
1458-
KernelInfo->BufferArgs[argIndex] = std::move(MemBuffer);
1434+
auto &KernelInfo = getAsanInterceptor()->getOrCreateKernelInfo(hKernel);
1435+
std::scoped_lock<ur_shared_mutex> Guard(KernelInfo.Mutex);
1436+
KernelInfo.BufferArgs[argIndex] = std::move(MemBuffer);
14591437
} else {
14601438
UR_CALL(pfnSetArgMemObj(hKernel, argIndex, pProperties, hArgValue));
14611439
}
@@ -1484,12 +1462,12 @@ __urdlllocal ur_result_t UR_APICALL urKernelSetArgLocal(
14841462
argSize);
14851463

14861464
{
1487-
auto KI = getAsanInterceptor()->getKernelInfo(hKernel);
1488-
std::scoped_lock<ur_shared_mutex> Guard(KI->Mutex);
1465+
auto &KI = getAsanInterceptor()->getOrCreateKernelInfo(hKernel);
1466+
std::scoped_lock<ur_shared_mutex> Guard(KI.Mutex);
14891467
// TODO: get local variable alignment
14901468
auto argSizeWithRZ = GetSizeAndRedzoneSizeForLocal(
14911469
argSize, ASAN_SHADOW_GRANULARITY, ASAN_SHADOW_GRANULARITY);
1492-
KI->LocalArgs[argIndex] = LocalArgsInfo{argSize, argSizeWithRZ};
1470+
KI.LocalArgs[argIndex] = LocalArgsInfo{argSize, argSizeWithRZ};
14931471
argSize = argSizeWithRZ;
14941472
}
14951473

@@ -1522,9 +1500,9 @@ __urdlllocal ur_result_t UR_APICALL urKernelSetArgPointer(
15221500

15231501
std::shared_ptr<KernelInfo> KI;
15241502
if (getAsanInterceptor()->getOptions().DetectKernelArguments) {
1525-
auto KI = getAsanInterceptor()->getKernelInfo(hKernel);
1526-
std::scoped_lock<ur_shared_mutex> Guard(KI->Mutex);
1527-
KI->PointerArgs[argIndex] = {pArgValue, GetCurrentBacktrace()};
1503+
auto &KI = getAsanInterceptor()->getOrCreateKernelInfo(hKernel);
1504+
std::scoped_lock<ur_shared_mutex> Guard(KI.Mutex);
1505+
KI.PointerArgs[argIndex] = {pArgValue, GetCurrentBacktrace()};
15281506
}
15291507

15301508
ur_result_t result =
@@ -1708,7 +1686,6 @@ __urdlllocal ur_result_t UR_APICALL urGetKernelProcAddrTable(
17081686

17091687
ur_result_t result = UR_RESULT_SUCCESS;
17101688

1711-
pDdiTable->pfnCreate = ur_sanitizer_layer::asan::urKernelCreate;
17121689
pDdiTable->pfnRetain = ur_sanitizer_layer::asan::urKernelRetain;
17131690
pDdiTable->pfnRelease = ur_sanitizer_layer::asan::urKernelRelease;
17141691
pDdiTable->pfnSetArgValue = ur_sanitizer_layer::asan::urKernelSetArgValue;

source/loader/layers/sanitizer/asan/asan_interceptor.cpp

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -639,22 +639,26 @@ ur_result_t AsanInterceptor::eraseProgram(ur_program_handle_t Program) {
639639
return UR_RESULT_SUCCESS;
640640
}
641641

642-
ur_result_t AsanInterceptor::insertKernel(ur_kernel_handle_t Kernel) {
643-
std::scoped_lock<ur_shared_mutex> Guard(m_KernelMapMutex);
644-
if (m_KernelMap.find(Kernel) != m_KernelMap.end()) {
645-
return UR_RESULT_SUCCESS;
642+
KernelInfo &AsanInterceptor::getOrCreateKernelInfo(ur_kernel_handle_t Kernel) {
643+
{
644+
std::shared_lock<ur_shared_mutex> Guard(m_KernelMapMutex);
645+
if (m_KernelMap.find(Kernel) != m_KernelMap.end()) {
646+
return *m_KernelMap[Kernel].get();
647+
}
646648
}
647649

650+
// Create new KernelInfo
648651
auto hProgram = GetProgram(Kernel);
649652
auto PI = getAsanInterceptor()->getProgramInfo(hProgram);
650653
bool IsInstrumented = PI->isKernelInstrumented(Kernel);
651654

655+
std::scoped_lock<ur_shared_mutex> Guard(m_KernelMapMutex);
652656
m_KernelMap.emplace(Kernel,
653-
std::make_shared<KernelInfo>(Kernel, IsInstrumented));
654-
return UR_RESULT_SUCCESS;
657+
std::make_unique<KernelInfo>(Kernel, IsInstrumented));
658+
return *m_KernelMap[Kernel].get();
655659
}
656660

657-
ur_result_t AsanInterceptor::eraseKernel(ur_kernel_handle_t Kernel) {
661+
ur_result_t AsanInterceptor::eraseKernelInfo(ur_kernel_handle_t Kernel) {
658662
std::scoped_lock<ur_shared_mutex> Guard(m_KernelMapMutex);
659663
assert(m_KernelMap.find(Kernel) != m_KernelMap.end());
660664
m_KernelMap.erase(Kernel);
@@ -691,7 +695,7 @@ ur_result_t AsanInterceptor::prepareLaunch(
691695
std::shared_ptr<ContextInfo> &ContextInfo,
692696
std::shared_ptr<DeviceInfo> &DeviceInfo, ur_queue_handle_t Queue,
693697
ur_kernel_handle_t Kernel, LaunchInfo &LaunchInfo) {
694-
auto KernelInfo = getKernelInfo(Kernel);
698+
auto &KernelInfo = getOrCreateKernelInfo(Kernel);
695699

696700
auto ArgNums = GetKernelNumArgs(Kernel);
697701
auto LocalMemoryUsage =
@@ -703,11 +707,11 @@ ur_result_t AsanInterceptor::prepareLaunch(
703707
"KernelInfo {} (Name={}, ArgNums={}, IsInstrumented={}, "
704708
"LocalMemory={}, PrivateMemory={})",
705709
(void *)Kernel, GetKernelName(Kernel), ArgNums,
706-
KernelInfo->IsInstrumented, LocalMemoryUsage, PrivateMemoryUsage);
710+
KernelInfo.IsInstrumented, LocalMemoryUsage, PrivateMemoryUsage);
707711

708712
// Validate pointer arguments
709713
if (getOptions().DetectKernelArguments) {
710-
for (const auto &[ArgIndex, PtrPair] : KernelInfo->PointerArgs) {
714+
for (const auto &[ArgIndex, PtrPair] : KernelInfo.PointerArgs) {
711715
auto Ptr = PtrPair.first;
712716
if (Ptr == nullptr) {
713717
continue;
@@ -722,7 +726,7 @@ ur_result_t AsanInterceptor::prepareLaunch(
722726
}
723727

724728
// Set membuffer arguments
725-
for (const auto &[ArgIndex, MemBuffer] : KernelInfo->BufferArgs) {
729+
for (const auto &[ArgIndex, MemBuffer] : KernelInfo.BufferArgs) {
726730
char *ArgPointer = nullptr;
727731
UR_CALL(MemBuffer->getHandle(DeviceInfo->Handle, ArgPointer));
728732
ur_result_t URes = getContext()->urDdiTable.Kernel.pfnSetArgPointer(
@@ -735,7 +739,7 @@ ur_result_t AsanInterceptor::prepareLaunch(
735739
}
736740
}
737741

738-
if (!KernelInfo->IsInstrumented) {
742+
if (!KernelInfo.IsInstrumented) {
739743
return UR_RESULT_SUCCESS;
740744
}
741745

@@ -830,9 +834,9 @@ ur_result_t AsanInterceptor::prepareLaunch(
830834
}
831835

832836
// Write local arguments info
833-
if (!KernelInfo->LocalArgs.empty()) {
837+
if (!KernelInfo.LocalArgs.empty()) {
834838
std::vector<LocalArgsInfo> LocalArgsInfo;
835-
for (auto [ArgIndex, ArgInfo] : KernelInfo->LocalArgs) {
839+
for (auto [ArgIndex, ArgInfo] : KernelInfo.LocalArgs) {
836840
LocalArgsInfo.push_back(ArgInfo);
837841
getContext()->logger.debug(
838842
"local_args (argIndex={}, size={}, sizeWithRZ={})", ArgIndex,

source/loader/layers/sanitizer/asan/asan_interceptor.hpp

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -308,9 +308,6 @@ class AsanInterceptor {
308308
ur_result_t insertProgram(ur_program_handle_t Program);
309309
ur_result_t eraseProgram(ur_program_handle_t Program);
310310

311-
ur_result_t insertKernel(ur_kernel_handle_t Kernel);
312-
ur_result_t eraseKernel(ur_kernel_handle_t Kernel);
313-
314311
ur_result_t insertMemBuffer(std::shared_ptr<MemBuffer> MemBuffer);
315312
ur_result_t eraseMemBuffer(ur_mem_handle_t MemHandle);
316313
std::shared_ptr<MemBuffer> getMemBuffer(ur_mem_handle_t MemHandle);
@@ -350,11 +347,8 @@ class AsanInterceptor {
350347
return nullptr;
351348
}
352349

353-
std::shared_ptr<KernelInfo> getKernelInfo(ur_kernel_handle_t Kernel) {
354-
std::shared_lock<ur_shared_mutex> Guard(m_KernelMapMutex);
355-
assert(m_KernelMap.find(Kernel) != m_KernelMap.end());
356-
return m_KernelMap[Kernel];
357-
}
350+
KernelInfo &getOrCreateKernelInfo(ur_kernel_handle_t Kernel);
351+
ur_result_t eraseKernelInfo(ur_kernel_handle_t Kernel);
358352

359353
const AsanOptions &getOptions() { return m_Options; }
360354

@@ -401,7 +395,7 @@ class AsanInterceptor {
401395
m_ProgramMap;
402396
ur_shared_mutex m_ProgramMapMutex;
403397

404-
std::unordered_map<ur_kernel_handle_t, std::shared_ptr<KernelInfo>>
398+
std::unordered_map<ur_kernel_handle_t, std::unique_ptr<KernelInfo>>
405399
m_KernelMap;
406400
ur_shared_mutex m_KernelMapMutex;
407401

0 commit comments

Comments
 (0)