Skip to content

Commit 533ab9b

Browse files
authored
Merge pull request #2513 from AllanZyne/review/yang/fix_msan_usm
[DeviceMSAN] Fix "urEnqueueUSM" APIs
2 parents e7366f9 + 3eeb2a1 commit 533ab9b

File tree

6 files changed

+408
-149
lines changed

6 files changed

+408
-149
lines changed

source/loader/layers/sanitizer/msan/msan_ddi.cpp

Lines changed: 257 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -45,7 +45,6 @@ ur_result_t setupContext(ur_context_handle_t Context, uint32_t numDevices,
4545
UR_CALL(DI->allocShadowMemory(Context));
4646
}
4747
CI->DeviceList.emplace_back(hDevice);
48-
CI->AllocInfosMap[hDevice];
4948
}
5049
return UR_RESULT_SUCCESS;
5150
}
@@ -104,6 +103,17 @@ ur_result_t urUSMDeviceAlloc(
104103
pool, size, ppMem);
105104
}
106105

106+
///////////////////////////////////////////////////////////////////////////////
107+
/// @brief Intercept function for urUSMFree
108+
__urdlllocal ur_result_t UR_APICALL urUSMFree(
109+
ur_context_handle_t hContext, ///< [in] handle of the context object
110+
void *pMem ///< [in] pointer to USM memory object
111+
) {
112+
getContext()->logger.debug("==== urUSMFree");
113+
114+
return getMsanInterceptor()->releaseMemory(hContext, pMem);
115+
}
116+
107117
///////////////////////////////////////////////////////////////////////////////
108118
/// @brief Intercept function for urProgramCreateWithIL
109119
ur_result_t urProgramCreateWithIL(
@@ -1234,6 +1244,247 @@ ur_result_t urKernelSetArgMemObj(
12341244
return UR_RESULT_SUCCESS;
12351245
}
12361246

1247+
///////////////////////////////////////////////////////////////////////////////
1248+
/// @brief Intercept function for urEnqueueUSMFill
1249+
ur_result_t UR_APICALL urEnqueueUSMFill(
1250+
ur_queue_handle_t hQueue, ///< [in] handle of the queue object
1251+
void *pMem, ///< [in][bounds(0, size)] pointer to USM memory object
1252+
size_t
1253+
patternSize, ///< [in] the size in bytes of the pattern. Must be a power of 2 and less
1254+
///< than or equal to width.
1255+
const void
1256+
*pPattern, ///< [in] pointer with the bytes of the pattern to set.
1257+
size_t
1258+
size, ///< [in] size in bytes to be set. Must be a multiple of patternSize.
1259+
uint32_t numEventsInWaitList, ///< [in] size of the event wait list
1260+
const ur_event_handle_t *
1261+
phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of
1262+
///< events that must be complete before this command can be executed.
1263+
///< If nullptr, the numEventsInWaitList must be 0, indicating that this
1264+
///< command does not wait on any event to complete.
1265+
ur_event_handle_t *
1266+
phEvent ///< [out][optional] return an event object that identifies this particular
1267+
///< command instance. If phEventWaitList and phEvent are not NULL, phEvent
1268+
///< must not refer to an element of the phEventWaitList array.
1269+
) {
1270+
auto pfnUSMFill = getContext()->urDdiTable.Enqueue.pfnUSMFill;
1271+
getContext()->logger.debug("==== urEnqueueUSMFill");
1272+
1273+
ur_event_handle_t hEvents[2] = {};
1274+
UR_CALL(pfnUSMFill(hQueue, pMem, patternSize, pPattern, size,
1275+
numEventsInWaitList, phEventWaitList, &hEvents[0]));
1276+
1277+
const auto Mem = (uptr)pMem;
1278+
auto MemInfoItOp = getMsanInterceptor()->findAllocInfoByAddress(Mem);
1279+
if (MemInfoItOp) {
1280+
auto MemInfo = (*MemInfoItOp)->second;
1281+
1282+
const auto &DeviceInfo =
1283+
getMsanInterceptor()->getDeviceInfo(MemInfo->Device);
1284+
const auto MemShadow = DeviceInfo->Shadow->MemToShadow(Mem);
1285+
1286+
UR_CALL(EnqueueUSMBlockingSet(hQueue, (void *)MemShadow, 0, size, 0,
1287+
nullptr, &hEvents[1]));
1288+
}
1289+
1290+
if (phEvent) {
1291+
UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait(
1292+
hQueue, 2, hEvents, phEvent));
1293+
}
1294+
1295+
return UR_RESULT_SUCCESS;
1296+
}
1297+
1298+
///////////////////////////////////////////////////////////////////////////////
1299+
/// @brief Intercept function for urEnqueueUSMMemcpy
1300+
ur_result_t UR_APICALL urEnqueueUSMMemcpy(
1301+
ur_queue_handle_t hQueue, ///< [in] handle of the queue object
1302+
bool blocking, ///< [in] blocking or non-blocking copy
1303+
void *
1304+
pDst, ///< [in][bounds(0, size)] pointer to the destination USM memory object
1305+
const void *
1306+
pSrc, ///< [in][bounds(0, size)] pointer to the source USM memory object
1307+
size_t size, ///< [in] size in bytes to be copied
1308+
uint32_t numEventsInWaitList, ///< [in] size of the event wait list
1309+
const ur_event_handle_t *
1310+
phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of
1311+
///< events that must be complete before this command can be executed.
1312+
///< If nullptr, the numEventsInWaitList must be 0, indicating that this
1313+
///< command does not wait on any event to complete.
1314+
ur_event_handle_t *
1315+
phEvent ///< [out][optional] return an event object that identifies this particular
1316+
///< command instance. If phEventWaitList and phEvent are not NULL, phEvent
1317+
///< must not refer to an element of the phEventWaitList array.
1318+
) {
1319+
auto pfnUSMMemcpy = getContext()->urDdiTable.Enqueue.pfnUSMMemcpy;
1320+
getContext()->logger.debug("==== pfnUSMMemcpy");
1321+
1322+
ur_event_handle_t hEvents[2] = {};
1323+
UR_CALL(pfnUSMMemcpy(hQueue, blocking, pDst, pSrc, size,
1324+
numEventsInWaitList, phEventWaitList, &hEvents[0]));
1325+
1326+
const auto Src = (uptr)pSrc, Dst = (uptr)pDst;
1327+
auto SrcInfoItOp = getMsanInterceptor()->findAllocInfoByAddress(Src);
1328+
auto DstInfoItOp = getMsanInterceptor()->findAllocInfoByAddress(Dst);
1329+
1330+
if (SrcInfoItOp && DstInfoItOp) {
1331+
auto SrcInfo = (*SrcInfoItOp)->second;
1332+
auto DstInfo = (*DstInfoItOp)->second;
1333+
1334+
const auto &DeviceInfo =
1335+
getMsanInterceptor()->getDeviceInfo(SrcInfo->Device);
1336+
const auto SrcShadow = DeviceInfo->Shadow->MemToShadow(Src);
1337+
const auto DstShadow = DeviceInfo->Shadow->MemToShadow(Dst);
1338+
1339+
UR_CALL(pfnUSMMemcpy(hQueue, blocking, (void *)DstShadow,
1340+
(void *)SrcShadow, size, 0, nullptr, &hEvents[1]));
1341+
} else if (DstInfoItOp) {
1342+
auto DstInfo = (*DstInfoItOp)->second;
1343+
1344+
const auto &DeviceInfo =
1345+
getMsanInterceptor()->getDeviceInfo(DstInfo->Device);
1346+
auto DstShadow = DeviceInfo->Shadow->MemToShadow(Dst);
1347+
1348+
UR_CALL(EnqueueUSMBlockingSet(hQueue, (void *)DstShadow, 0, size, 0,
1349+
nullptr, &hEvents[1]));
1350+
}
1351+
1352+
if (phEvent) {
1353+
UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait(
1354+
hQueue, 2, hEvents, phEvent));
1355+
}
1356+
1357+
return UR_RESULT_SUCCESS;
1358+
}
1359+
1360+
///////////////////////////////////////////////////////////////////////////////
1361+
/// @brief Intercept function for urEnqueueUSMFill2D
1362+
ur_result_t UR_APICALL urEnqueueUSMFill2D(
1363+
ur_queue_handle_t hQueue, ///< [in] handle of the queue to submit to.
1364+
void *
1365+
pMem, ///< [in][bounds(0, pitch * height)] pointer to memory to be filled.
1366+
size_t
1367+
pitch, ///< [in] the total width of the destination memory including padding.
1368+
size_t
1369+
patternSize, ///< [in] the size in bytes of the pattern. Must be a power of 2 and less
1370+
///< than or equal to width.
1371+
const void
1372+
*pPattern, ///< [in] pointer with the bytes of the pattern to set.
1373+
size_t
1374+
width, ///< [in] the width in bytes of each row to fill. Must be a multiple of
1375+
///< patternSize.
1376+
size_t height, ///< [in] the height of the columns to fill.
1377+
uint32_t numEventsInWaitList, ///< [in] size of the event wait list
1378+
const ur_event_handle_t *
1379+
phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of
1380+
///< events that must be complete before the kernel execution.
1381+
///< If nullptr, the numEventsInWaitList must be 0, indicating that no wait event.
1382+
ur_event_handle_t *
1383+
phEvent ///< [out][optional] return an event object that identifies this particular
1384+
///< kernel execution instance. If phEventWaitList and phEvent are not
1385+
///< NULL, phEvent must not refer to an element of the phEventWaitList array.
1386+
) {
1387+
auto pfnUSMFill2D = getContext()->urDdiTable.Enqueue.pfnUSMFill2D;
1388+
getContext()->logger.debug("==== urEnqueueUSMFill2D");
1389+
1390+
ur_event_handle_t hEvents[2] = {};
1391+
UR_CALL(pfnUSMFill2D(hQueue, pMem, pitch, patternSize, pPattern, width,
1392+
height, numEventsInWaitList, phEventWaitList,
1393+
&hEvents[0]));
1394+
1395+
const auto Mem = (uptr)pMem;
1396+
auto MemInfoItOp = getMsanInterceptor()->findAllocInfoByAddress(Mem);
1397+
if (MemInfoItOp) {
1398+
auto MemInfo = (*MemInfoItOp)->second;
1399+
1400+
const auto &DeviceInfo =
1401+
getMsanInterceptor()->getDeviceInfo(MemInfo->Device);
1402+
const auto MemShadow = DeviceInfo->Shadow->MemToShadow(Mem);
1403+
1404+
const char Pattern = 0;
1405+
UR_CALL(pfnUSMFill2D(hQueue, (void *)MemShadow, pitch, 1, &Pattern,
1406+
width, height, 0, nullptr, &hEvents[1]));
1407+
}
1408+
1409+
if (phEvent) {
1410+
UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait(
1411+
hQueue, 2, hEvents, phEvent));
1412+
}
1413+
1414+
return UR_RESULT_SUCCESS;
1415+
}
1416+
1417+
///////////////////////////////////////////////////////////////////////////////
1418+
/// @brief Intercept function for urEnqueueUSMMemcpy2D
1419+
ur_result_t UR_APICALL urEnqueueUSMMemcpy2D(
1420+
ur_queue_handle_t hQueue, ///< [in] handle of the queue to submit to.
1421+
bool blocking, ///< [in] indicates if this operation should block the host.
1422+
void *
1423+
pDst, ///< [in][bounds(0, dstPitch * height)] pointer to memory where data will
1424+
///< be copied.
1425+
size_t
1426+
dstPitch, ///< [in] the total width of the source memory including padding.
1427+
const void *
1428+
pSrc, ///< [in][bounds(0, srcPitch * height)] pointer to memory to be copied.
1429+
size_t
1430+
srcPitch, ///< [in] the total width of the source memory including padding.
1431+
size_t width, ///< [in] the width in bytes of each row to be copied.
1432+
size_t height, ///< [in] the height of columns to be copied.
1433+
uint32_t numEventsInWaitList, ///< [in] size of the event wait list
1434+
const ur_event_handle_t *
1435+
phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of
1436+
///< events that must be complete before the kernel execution.
1437+
///< If nullptr, the numEventsInWaitList must be 0, indicating that no wait event.
1438+
ur_event_handle_t *
1439+
phEvent ///< [out][optional] return an event object that identifies this particular
1440+
///< kernel execution instance. If phEventWaitList and phEvent are not
1441+
///< NULL, phEvent must not refer to an element of the phEventWaitList array.
1442+
) {
1443+
auto pfnUSMMemcpy2D = getContext()->urDdiTable.Enqueue.pfnUSMMemcpy2D;
1444+
getContext()->logger.debug("==== pfnUSMMemcpy2D");
1445+
1446+
ur_event_handle_t hEvents[2] = {};
1447+
UR_CALL(pfnUSMMemcpy2D(hQueue, blocking, pDst, dstPitch, pSrc, srcPitch,
1448+
width, height, numEventsInWaitList, phEventWaitList,
1449+
&hEvents[0]));
1450+
1451+
const auto Src = (uptr)pSrc, Dst = (uptr)pDst;
1452+
auto SrcInfoItOp = getMsanInterceptor()->findAllocInfoByAddress(Src);
1453+
auto DstInfoItOp = getMsanInterceptor()->findAllocInfoByAddress(Dst);
1454+
1455+
if (SrcInfoItOp && DstInfoItOp) {
1456+
auto SrcInfo = (*SrcInfoItOp)->second;
1457+
auto DstInfo = (*DstInfoItOp)->second;
1458+
1459+
const auto &DeviceInfo =
1460+
getMsanInterceptor()->getDeviceInfo(SrcInfo->Device);
1461+
const auto SrcShadow = DeviceInfo->Shadow->MemToShadow(Src);
1462+
const auto DstShadow = DeviceInfo->Shadow->MemToShadow(Dst);
1463+
1464+
UR_CALL(pfnUSMMemcpy2D(hQueue, blocking, (void *)DstShadow, dstPitch,
1465+
(void *)SrcShadow, srcPitch, width, height, 0,
1466+
nullptr, &hEvents[1]));
1467+
} else if (DstInfoItOp) {
1468+
auto DstInfo = (*DstInfoItOp)->second;
1469+
1470+
const auto &DeviceInfo =
1471+
getMsanInterceptor()->getDeviceInfo(DstInfo->Device);
1472+
const auto DstShadow = DeviceInfo->Shadow->MemToShadow(Dst);
1473+
1474+
const char Pattern = 0;
1475+
UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMFill2D(
1476+
hQueue, (void *)DstShadow, dstPitch, 1, &Pattern, width, height, 0,
1477+
nullptr, &hEvents[1]));
1478+
}
1479+
1480+
if (phEvent) {
1481+
UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait(
1482+
hQueue, 2, hEvents, phEvent));
1483+
}
1484+
1485+
return UR_RESULT_SUCCESS;
1486+
}
1487+
12371488
///////////////////////////////////////////////////////////////////////////////
12381489
/// @brief Exported function for filling application's Global table
12391490
/// with current process' addresses
@@ -1391,6 +1642,10 @@ ur_result_t urGetEnqueueProcAddrTable(
13911642
pDdiTable->pfnMemUnmap = ur_sanitizer_layer::msan::urEnqueueMemUnmap;
13921643
pDdiTable->pfnKernelLaunch =
13931644
ur_sanitizer_layer::msan::urEnqueueKernelLaunch;
1645+
pDdiTable->pfnUSMFill = ur_sanitizer_layer::msan::urEnqueueUSMFill;
1646+
pDdiTable->pfnUSMMemcpy = ur_sanitizer_layer::msan::urEnqueueUSMMemcpy;
1647+
pDdiTable->pfnUSMFill2D = ur_sanitizer_layer::msan::urEnqueueUSMFill2D;
1648+
pDdiTable->pfnUSMMemcpy2D = ur_sanitizer_layer::msan::urEnqueueUSMMemcpy2D;
13941649

13951650
return result;
13961651
}
@@ -1408,6 +1663,7 @@ ur_result_t urGetUSMProcAddrTable(
14081663
ur_result_t result = UR_RESULT_SUCCESS;
14091664

14101665
pDdiTable->pfnDeviceAlloc = ur_sanitizer_layer::msan::urUSMDeviceAlloc;
1666+
pDdiTable->pfnFree = ur_sanitizer_layer::msan::urUSMFree;
14111667

14121668
return result;
14131669
}

0 commit comments

Comments
 (0)