Skip to content

Commit 20c8b2a

Browse files
committed
sync asan to msan
1 parent 4022e19 commit 20c8b2a

File tree

4 files changed

+45
-74
lines changed

4 files changed

+45
-74
lines changed

source/loader/layers/sanitizer/asan/asan_interceptor.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -648,8 +648,8 @@ KernelInfo &AsanInterceptor::getOrCreateKernelInfo(ur_kernel_handle_t Kernel) {
648648
}
649649

650650
// Create new KernelInfo
651-
auto hProgram = GetProgram(Kernel);
652-
auto PI = getAsanInterceptor()->getProgramInfo(hProgram);
651+
auto Program = GetProgram(Kernel);
652+
auto PI = getProgramInfo(Program);
653653
bool IsInstrumented = PI->isKernelInstrumented(Kernel);
654654

655655
std::scoped_lock<ur_shared_mutex> Guard(m_KernelMapMutex);

source/loader/layers/sanitizer/msan/msan_ddi.cpp

Lines changed: 13 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -50,12 +50,6 @@ ur_result_t setupContext(ur_context_handle_t Context, uint32_t numDevices,
5050
return UR_RESULT_SUCCESS;
5151
}
5252

53-
bool isInstrumentedKernel(ur_kernel_handle_t hKernel) {
54-
auto hProgram = GetProgram(hKernel);
55-
auto PI = getMsanInterceptor()->getProgramInfo(hProgram);
56-
return PI->isKernelInstrumented(hKernel);
57-
}
58-
5953
} // namespace
6054

6155
///////////////////////////////////////////////////////////////////////////////
@@ -354,12 +348,6 @@ ur_result_t urEnqueueKernelLaunch(
354348

355349
getContext()->logger.debug("==== urEnqueueKernelLaunch");
356350

357-
if (!isInstrumentedKernel(hKernel)) {
358-
return pfnKernelLaunch(hQueue, hKernel, workDim, pGlobalWorkOffset,
359-
pGlobalWorkSize, pLocalWorkSize,
360-
numEventsInWaitList, phEventWaitList, phEvent);
361-
}
362-
363351
USMLaunchInfo LaunchInfo(GetContext(hQueue), GetDevice(hQueue),
364352
pGlobalWorkSize, pLocalWorkSize, pGlobalWorkOffset,
365353
workDim);
@@ -1155,26 +1143,6 @@ ur_result_t urEnqueueMemUnmap(
11551143
return UR_RESULT_SUCCESS;
11561144
}
11571145

1158-
///////////////////////////////////////////////////////////////////////////////
1159-
/// @brief Intercept function for urKernelCreate
1160-
ur_result_t urKernelCreate(
1161-
ur_program_handle_t hProgram, ///< [in] handle of the program instance
1162-
const char *pKernelName, ///< [in] pointer to null-terminated string.
1163-
ur_kernel_handle_t
1164-
*phKernel ///< [out] pointer to handle of kernel object created.
1165-
) {
1166-
auto pfnCreate = getContext()->urDdiTable.Kernel.pfnCreate;
1167-
1168-
getContext()->logger.debug("==== urKernelCreate");
1169-
1170-
UR_CALL(pfnCreate(hProgram, pKernelName, phKernel));
1171-
if (isInstrumentedKernel(*phKernel)) {
1172-
UR_CALL(getMsanInterceptor()->insertKernel(*phKernel));
1173-
}
1174-
1175-
return UR_RESULT_SUCCESS;
1176-
}
1177-
11781146
///////////////////////////////////////////////////////////////////////////////
11791147
/// @brief Intercept function for urKernelRetain
11801148
ur_result_t urKernelRetain(
@@ -1186,10 +1154,8 @@ ur_result_t urKernelRetain(
11861154

11871155
UR_CALL(pfnRetain(hKernel));
11881156

1189-
auto KernelInfo = getMsanInterceptor()->getKernelInfo(hKernel);
1190-
if (KernelInfo) {
1191-
KernelInfo->RefCount++;
1192-
}
1157+
auto &KernelInfo = getMsanInterceptor()->getOrCreateKernelInfo(hKernel);
1158+
KernelInfo.RefCount++;
11931159

11941160
return UR_RESULT_SUCCESS;
11951161
}
@@ -1204,11 +1170,9 @@ ur_result_t urKernelRelease(
12041170
getContext()->logger.debug("==== urKernelRelease");
12051171
UR_CALL(pfnRelease(hKernel));
12061172

1207-
auto KernelInfo = getMsanInterceptor()->getKernelInfo(hKernel);
1208-
if (KernelInfo) {
1209-
if (--KernelInfo->RefCount == 0) {
1210-
UR_CALL(getMsanInterceptor()->eraseKernel(hKernel));
1211-
}
1173+
auto &KernelInfo = getMsanInterceptor()->getOrCreateKernelInfo(hKernel);
1174+
if (--KernelInfo.RefCount == 0) {
1175+
UR_CALL(getMsanInterceptor()->eraseKernelInfo(hKernel));
12121176
}
12131177

12141178
return UR_RESULT_SUCCESS;
@@ -1230,13 +1194,12 @@ ur_result_t urKernelSetArgValue(
12301194
getContext()->logger.debug("==== urKernelSetArgValue");
12311195

12321196
std::shared_ptr<MemBuffer> MemBuffer;
1233-
std::shared_ptr<KernelInfo> KernelInfo;
12341197
if (argSize == sizeof(ur_mem_handle_t) &&
12351198
(MemBuffer = getMsanInterceptor()->getMemBuffer(
1236-
*ur_cast<const ur_mem_handle_t *>(pArgValue))) &&
1237-
(KernelInfo = getMsanInterceptor()->getKernelInfo(hKernel))) {
1238-
std::scoped_lock<ur_shared_mutex> Guard(KernelInfo->Mutex);
1239-
KernelInfo->BufferArgs[argIndex] = std::move(MemBuffer);
1199+
*ur_cast<const ur_mem_handle_t *>(pArgValue)))) {
1200+
auto &KernelInfo = getMsanInterceptor()->getOrCreateKernelInfo(hKernel);
1201+
std::scoped_lock<ur_shared_mutex> Guard(KernelInfo.Mutex);
1202+
KernelInfo.BufferArgs[argIndex] = std::move(MemBuffer);
12401203
} else {
12411204
UR_CALL(
12421205
pfnSetArgValue(hKernel, argIndex, argSize, pProperties, pArgValue));
@@ -1260,10 +1223,10 @@ ur_result_t urKernelSetArgMemObj(
12601223

12611224
std::shared_ptr<MemBuffer> MemBuffer;
12621225
std::shared_ptr<KernelInfo> KernelInfo;
1263-
if ((MemBuffer = getMsanInterceptor()->getMemBuffer(hArgValue)) &&
1264-
(KernelInfo = getMsanInterceptor()->getKernelInfo(hKernel))) {
1265-
std::scoped_lock<ur_shared_mutex> Guard(KernelInfo->Mutex);
1266-
KernelInfo->BufferArgs[argIndex] = std::move(MemBuffer);
1226+
if ((MemBuffer = getMsanInterceptor()->getMemBuffer(hArgValue))) {
1227+
auto &KernelInfo = getMsanInterceptor()->getOrCreateKernelInfo(hKernel);
1228+
std::scoped_lock<ur_shared_mutex> Guard(KernelInfo.Mutex);
1229+
KernelInfo.BufferArgs[argIndex] = std::move(MemBuffer);
12671230
} else {
12681231
UR_CALL(pfnSetArgMemObj(hKernel, argIndex, pProperties, hArgValue));
12691232
}
@@ -1348,7 +1311,6 @@ ur_result_t urGetKernelProcAddrTable(
13481311
) {
13491312
ur_result_t result = UR_RESULT_SUCCESS;
13501313

1351-
pDdiTable->pfnCreate = ur_sanitizer_layer::msan::urKernelCreate;
13521314
pDdiTable->pfnRetain = ur_sanitizer_layer::msan::urKernelRetain;
13531315
pDdiTable->pfnRelease = ur_sanitizer_layer::msan::urKernelRelease;
13541316
pDdiTable->pfnSetArgValue = ur_sanitizer_layer::msan::urKernelSetArgValue;

source/loader/layers/sanitizer/msan/msan_interceptor.cpp

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -298,16 +298,26 @@ ur_result_t MsanInterceptor::eraseProgram(ur_program_handle_t Program) {
298298
return UR_RESULT_SUCCESS;
299299
}
300300

301-
ur_result_t MsanInterceptor::insertKernel(ur_kernel_handle_t Kernel) {
302-
std::scoped_lock<ur_shared_mutex> Guard(m_KernelMapMutex);
303-
if (m_KernelMap.find(Kernel) != m_KernelMap.end()) {
304-
return UR_RESULT_SUCCESS;
301+
KernelInfo &MsanInterceptor::getOrCreateKernelInfo(ur_kernel_handle_t Kernel) {
302+
{
303+
std::shared_lock<ur_shared_mutex> Guard(m_KernelMapMutex);
304+
if (m_KernelMap.find(Kernel) != m_KernelMap.end()) {
305+
return *m_KernelMap[Kernel].get();
306+
}
305307
}
306-
m_KernelMap.emplace(Kernel, std::make_shared<KernelInfo>(Kernel));
307-
return UR_RESULT_SUCCESS;
308+
309+
// Create new KernelInfo
310+
auto Program = GetProgram(Kernel);
311+
auto PI = getProgramInfo(Program);
312+
bool IsInstrumented = PI->isKernelInstrumented(Kernel);
313+
314+
std::scoped_lock<ur_shared_mutex> Guard(m_KernelMapMutex);
315+
m_KernelMap.emplace(Kernel,
316+
std::make_unique<KernelInfo>(Kernel, IsInstrumented));
317+
return *m_KernelMap[Kernel].get();
308318
}
309319

310-
ur_result_t MsanInterceptor::eraseKernel(ur_kernel_handle_t Kernel) {
320+
ur_result_t MsanInterceptor::eraseKernelInfo(ur_kernel_handle_t Kernel) {
311321
std::scoped_lock<ur_shared_mutex> Guard(m_KernelMapMutex);
312322
assert(m_KernelMap.find(Kernel) != m_KernelMap.end());
313323
m_KernelMap.erase(Kernel);
@@ -360,10 +370,9 @@ ur_result_t MsanInterceptor::prepareLaunch(
360370
};
361371

362372
// Set membuffer arguments
363-
auto KernelInfo = getKernelInfo(Kernel);
364-
assert(KernelInfo && "Kernel must be instrumented");
373+
auto &KernelInfo = getOrCreateKernelInfo(Kernel);
365374

366-
for (const auto &[ArgIndex, MemBuffer] : KernelInfo->BufferArgs) {
375+
for (const auto &[ArgIndex, MemBuffer] : KernelInfo.BufferArgs) {
367376
char *ArgPointer = nullptr;
368377
UR_CALL(MemBuffer->getHandle(DeviceInfo->Handle, ArgPointer));
369378
ur_result_t URes = getContext()->urDdiTable.Kernel.pfnSetArgPointer(
@@ -376,6 +385,10 @@ ur_result_t MsanInterceptor::prepareLaunch(
376385
}
377386
}
378387

388+
if (!KernelInfo.IsInstrumented) {
389+
return UR_RESULT_SUCCESS;
390+
}
391+
379392
// Set LaunchInfo
380393
LaunchInfo.Data->GlobalShadowOffset = DeviceInfo->Shadow->ShadowBegin;
381394
LaunchInfo.Data->GlobalShadowOffsetEnd = DeviceInfo->Shadow->ShadowEnd;

source/loader/layers/sanitizer/msan/msan_interceptor.hpp

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -76,11 +76,15 @@ struct KernelInfo {
7676
ur_kernel_handle_t Handle;
7777
std::atomic<int32_t> RefCount = 1;
7878

79+
// sanitized kernel
80+
bool IsInstrumented = false;
81+
7982
// lock this mutex if following fields are accessed
8083
ur_shared_mutex Mutex;
8184
std::unordered_map<uint32_t, std::shared_ptr<MemBuffer>> BufferArgs;
8285

83-
explicit KernelInfo(ur_kernel_handle_t Kernel) : Handle(Kernel) {
86+
explicit KernelInfo(ur_kernel_handle_t Kernel, bool IsInstrumented)
87+
: Handle(Kernel), IsInstrumented(IsInstrumented) {
8488
[[maybe_unused]] auto Result =
8589
getContext()->urDdiTable.Kernel.pfnRetain(Kernel);
8690
assert(Result == UR_RESULT_SUCCESS);
@@ -203,9 +207,6 @@ class MsanInterceptor {
203207
ur_result_t insertProgram(ur_program_handle_t Program);
204208
ur_result_t eraseProgram(ur_program_handle_t Program);
205209

206-
ur_result_t insertKernel(ur_kernel_handle_t Kernel);
207-
ur_result_t eraseKernel(ur_kernel_handle_t Kernel);
208-
209210
ur_result_t insertMemBuffer(std::shared_ptr<MemBuffer> MemBuffer);
210211
ur_result_t eraseMemBuffer(ur_mem_handle_t MemHandle);
211212
std::shared_ptr<MemBuffer> getMemBuffer(ur_mem_handle_t MemHandle);
@@ -245,13 +246,8 @@ class MsanInterceptor {
245246
return m_ProgramMap[Program];
246247
}
247248

248-
std::shared_ptr<msan::KernelInfo> getKernelInfo(ur_kernel_handle_t Kernel) {
249-
std::shared_lock<ur_shared_mutex> Guard(m_KernelMapMutex);
250-
if (m_KernelMap.find(Kernel) != m_KernelMap.end()) {
251-
return m_KernelMap[Kernel];
252-
}
253-
return nullptr;
254-
}
249+
KernelInfo &getOrCreateKernelInfo(ur_kernel_handle_t Kernel);
250+
ur_result_t eraseKernelInfo(ur_kernel_handle_t Kernel);
255251

256252
const MsanOptions &getOptions() { return m_Options; }
257253

0 commit comments

Comments
 (0)