@@ -45,7 +45,6 @@ ur_result_t setupContext(ur_context_handle_t Context, uint32_t numDevices,
45
45
UR_CALL (DI->allocShadowMemory (Context));
46
46
}
47
47
CI->DeviceList .emplace_back (hDevice);
48
- CI->AllocInfosMap [hDevice];
49
48
}
50
49
return UR_RESULT_SUCCESS;
51
50
}
@@ -104,6 +103,17 @@ ur_result_t urUSMDeviceAlloc(
104
103
pool, size, ppMem);
105
104
}
106
105
106
+ // /////////////////////////////////////////////////////////////////////////////
107
+ // / @brief Intercept function for urUSMFree
108
+ __urdlllocal ur_result_t UR_APICALL urUSMFree (
109
+ ur_context_handle_t hContext, // /< [in] handle of the context object
110
+ void *pMem // /< [in] pointer to USM memory object
111
+ ) {
112
+ getContext ()->logger .debug (" ==== urUSMFree" );
113
+
114
+ return getMsanInterceptor ()->releaseMemory (hContext, pMem);
115
+ }
116
+
107
117
// /////////////////////////////////////////////////////////////////////////////
108
118
// / @brief Intercept function for urProgramCreateWithIL
109
119
ur_result_t urProgramCreateWithIL (
@@ -1234,6 +1244,247 @@ ur_result_t urKernelSetArgMemObj(
1234
1244
return UR_RESULT_SUCCESS;
1235
1245
}
1236
1246
1247
+ // /////////////////////////////////////////////////////////////////////////////
1248
+ // / @brief Intercept function for urEnqueueUSMFill
1249
+ ur_result_t UR_APICALL urEnqueueUSMFill (
1250
+ ur_queue_handle_t hQueue, // /< [in] handle of the queue object
1251
+ void *pMem, // /< [in][bounds(0, size)] pointer to USM memory object
1252
+ size_t
1253
+ patternSize, // /< [in] the size in bytes of the pattern. Must be a power of 2 and less
1254
+ // /< than or equal to width.
1255
+ const void
1256
+ *pPattern, // /< [in] pointer with the bytes of the pattern to set.
1257
+ size_t
1258
+ size, // /< [in] size in bytes to be set. Must be a multiple of patternSize.
1259
+ uint32_t numEventsInWaitList, // /< [in] size of the event wait list
1260
+ const ur_event_handle_t *
1261
+ phEventWaitList, // /< [in][optional][range(0, numEventsInWaitList)] pointer to a list of
1262
+ // /< events that must be complete before this command can be executed.
1263
+ // /< If nullptr, the numEventsInWaitList must be 0, indicating that this
1264
+ // /< command does not wait on any event to complete.
1265
+ ur_event_handle_t *
1266
+ phEvent // /< [out][optional] return an event object that identifies this particular
1267
+ // /< command instance. If phEventWaitList and phEvent are not NULL, phEvent
1268
+ // /< must not refer to an element of the phEventWaitList array.
1269
+ ) {
1270
+ auto pfnUSMFill = getContext ()->urDdiTable .Enqueue .pfnUSMFill ;
1271
+ getContext ()->logger .debug (" ==== urEnqueueUSMFill" );
1272
+
1273
+ ur_event_handle_t hEvents[2 ] = {};
1274
+ UR_CALL (pfnUSMFill (hQueue, pMem, patternSize, pPattern, size,
1275
+ numEventsInWaitList, phEventWaitList, &hEvents[0 ]));
1276
+
1277
+ const auto Mem = (uptr)pMem;
1278
+ auto MemInfoItOp = getMsanInterceptor ()->findAllocInfoByAddress (Mem);
1279
+ if (MemInfoItOp) {
1280
+ auto MemInfo = (*MemInfoItOp)->second ;
1281
+
1282
+ const auto &DeviceInfo =
1283
+ getMsanInterceptor ()->getDeviceInfo (MemInfo->Device );
1284
+ const auto MemShadow = DeviceInfo->Shadow ->MemToShadow (Mem);
1285
+
1286
+ UR_CALL (EnqueueUSMBlockingSet (hQueue, (void *)MemShadow, 0 , size, 0 ,
1287
+ nullptr , &hEvents[1 ]));
1288
+ }
1289
+
1290
+ if (phEvent) {
1291
+ UR_CALL (getContext ()->urDdiTable .Enqueue .pfnEventsWait (
1292
+ hQueue, 2 , hEvents, phEvent));
1293
+ }
1294
+
1295
+ return UR_RESULT_SUCCESS;
1296
+ }
1297
+
1298
+ // /////////////////////////////////////////////////////////////////////////////
1299
+ // / @brief Intercept function for urEnqueueUSMMemcpy
1300
+ ur_result_t UR_APICALL urEnqueueUSMMemcpy (
1301
+ ur_queue_handle_t hQueue, // /< [in] handle of the queue object
1302
+ bool blocking, // /< [in] blocking or non-blocking copy
1303
+ void *
1304
+ pDst, // /< [in][bounds(0, size)] pointer to the destination USM memory object
1305
+ const void *
1306
+ pSrc, // /< [in][bounds(0, size)] pointer to the source USM memory object
1307
+ size_t size, // /< [in] size in bytes to be copied
1308
+ uint32_t numEventsInWaitList, // /< [in] size of the event wait list
1309
+ const ur_event_handle_t *
1310
+ phEventWaitList, // /< [in][optional][range(0, numEventsInWaitList)] pointer to a list of
1311
+ // /< events that must be complete before this command can be executed.
1312
+ // /< If nullptr, the numEventsInWaitList must be 0, indicating that this
1313
+ // /< command does not wait on any event to complete.
1314
+ ur_event_handle_t *
1315
+ phEvent // /< [out][optional] return an event object that identifies this particular
1316
+ // /< command instance. If phEventWaitList and phEvent are not NULL, phEvent
1317
+ // /< must not refer to an element of the phEventWaitList array.
1318
+ ) {
1319
+ auto pfnUSMMemcpy = getContext ()->urDdiTable .Enqueue .pfnUSMMemcpy ;
1320
+ getContext ()->logger .debug (" ==== pfnUSMMemcpy" );
1321
+
1322
+ ur_event_handle_t hEvents[2 ] = {};
1323
+ UR_CALL (pfnUSMMemcpy (hQueue, blocking, pDst, pSrc, size,
1324
+ numEventsInWaitList, phEventWaitList, &hEvents[0 ]));
1325
+
1326
+ const auto Src = (uptr)pSrc, Dst = (uptr)pDst;
1327
+ auto SrcInfoItOp = getMsanInterceptor ()->findAllocInfoByAddress (Src);
1328
+ auto DstInfoItOp = getMsanInterceptor ()->findAllocInfoByAddress (Dst);
1329
+
1330
+ if (SrcInfoItOp && DstInfoItOp) {
1331
+ auto SrcInfo = (*SrcInfoItOp)->second ;
1332
+ auto DstInfo = (*DstInfoItOp)->second ;
1333
+
1334
+ const auto &DeviceInfo =
1335
+ getMsanInterceptor ()->getDeviceInfo (SrcInfo->Device );
1336
+ const auto SrcShadow = DeviceInfo->Shadow ->MemToShadow (Src);
1337
+ const auto DstShadow = DeviceInfo->Shadow ->MemToShadow (Dst);
1338
+
1339
+ UR_CALL (pfnUSMMemcpy (hQueue, blocking, (void *)DstShadow,
1340
+ (void *)SrcShadow, size, 0 , nullptr , &hEvents[1 ]));
1341
+ } else if (DstInfoItOp) {
1342
+ auto DstInfo = (*DstInfoItOp)->second ;
1343
+
1344
+ const auto &DeviceInfo =
1345
+ getMsanInterceptor ()->getDeviceInfo (DstInfo->Device );
1346
+ auto DstShadow = DeviceInfo->Shadow ->MemToShadow (Dst);
1347
+
1348
+ UR_CALL (EnqueueUSMBlockingSet (hQueue, (void *)DstShadow, 0 , size, 0 ,
1349
+ nullptr , &hEvents[1 ]));
1350
+ }
1351
+
1352
+ if (phEvent) {
1353
+ UR_CALL (getContext ()->urDdiTable .Enqueue .pfnEventsWait (
1354
+ hQueue, 2 , hEvents, phEvent));
1355
+ }
1356
+
1357
+ return UR_RESULT_SUCCESS;
1358
+ }
1359
+
1360
+ // /////////////////////////////////////////////////////////////////////////////
1361
+ // / @brief Intercept function for urEnqueueUSMFill2D
1362
+ ur_result_t UR_APICALL urEnqueueUSMFill2D (
1363
+ ur_queue_handle_t hQueue, // /< [in] handle of the queue to submit to.
1364
+ void *
1365
+ pMem, // /< [in][bounds(0, pitch * height)] pointer to memory to be filled.
1366
+ size_t
1367
+ pitch, // /< [in] the total width of the destination memory including padding.
1368
+ size_t
1369
+ patternSize, // /< [in] the size in bytes of the pattern. Must be a power of 2 and less
1370
+ // /< than or equal to width.
1371
+ const void
1372
+ *pPattern, // /< [in] pointer with the bytes of the pattern to set.
1373
+ size_t
1374
+ width, // /< [in] the width in bytes of each row to fill. Must be a multiple of
1375
+ // /< patternSize.
1376
+ size_t height, // /< [in] the height of the columns to fill.
1377
+ uint32_t numEventsInWaitList, // /< [in] size of the event wait list
1378
+ const ur_event_handle_t *
1379
+ phEventWaitList, // /< [in][optional][range(0, numEventsInWaitList)] pointer to a list of
1380
+ // /< events that must be complete before the kernel execution.
1381
+ // /< If nullptr, the numEventsInWaitList must be 0, indicating that no wait event.
1382
+ ur_event_handle_t *
1383
+ phEvent // /< [out][optional] return an event object that identifies this particular
1384
+ // /< kernel execution instance. If phEventWaitList and phEvent are not
1385
+ // /< NULL, phEvent must not refer to an element of the phEventWaitList array.
1386
+ ) {
1387
+ auto pfnUSMFill2D = getContext ()->urDdiTable .Enqueue .pfnUSMFill2D ;
1388
+ getContext ()->logger .debug (" ==== urEnqueueUSMFill2D" );
1389
+
1390
+ ur_event_handle_t hEvents[2 ] = {};
1391
+ UR_CALL (pfnUSMFill2D (hQueue, pMem, pitch, patternSize, pPattern, width,
1392
+ height, numEventsInWaitList, phEventWaitList,
1393
+ &hEvents[0 ]));
1394
+
1395
+ const auto Mem = (uptr)pMem;
1396
+ auto MemInfoItOp = getMsanInterceptor ()->findAllocInfoByAddress (Mem);
1397
+ if (MemInfoItOp) {
1398
+ auto MemInfo = (*MemInfoItOp)->second ;
1399
+
1400
+ const auto &DeviceInfo =
1401
+ getMsanInterceptor ()->getDeviceInfo (MemInfo->Device );
1402
+ const auto MemShadow = DeviceInfo->Shadow ->MemToShadow (Mem);
1403
+
1404
+ const char Pattern = 0 ;
1405
+ UR_CALL (pfnUSMFill2D (hQueue, (void *)MemShadow, pitch, 1 , &Pattern,
1406
+ width, height, 0 , nullptr , &hEvents[1 ]));
1407
+ }
1408
+
1409
+ if (phEvent) {
1410
+ UR_CALL (getContext ()->urDdiTable .Enqueue .pfnEventsWait (
1411
+ hQueue, 2 , hEvents, phEvent));
1412
+ }
1413
+
1414
+ return UR_RESULT_SUCCESS;
1415
+ }
1416
+
1417
+ // /////////////////////////////////////////////////////////////////////////////
1418
+ // / @brief Intercept function for urEnqueueUSMMemcpy2D
1419
+ ur_result_t UR_APICALL urEnqueueUSMMemcpy2D (
1420
+ ur_queue_handle_t hQueue, // /< [in] handle of the queue to submit to.
1421
+ bool blocking, // /< [in] indicates if this operation should block the host.
1422
+ void *
1423
+ pDst, // /< [in][bounds(0, dstPitch * height)] pointer to memory where data will
1424
+ // /< be copied.
1425
+ size_t
1426
+ dstPitch, // /< [in] the total width of the source memory including padding.
1427
+ const void *
1428
+ pSrc, // /< [in][bounds(0, srcPitch * height)] pointer to memory to be copied.
1429
+ size_t
1430
+ srcPitch, // /< [in] the total width of the source memory including padding.
1431
+ size_t width, // /< [in] the width in bytes of each row to be copied.
1432
+ size_t height, // /< [in] the height of columns to be copied.
1433
+ uint32_t numEventsInWaitList, // /< [in] size of the event wait list
1434
+ const ur_event_handle_t *
1435
+ phEventWaitList, // /< [in][optional][range(0, numEventsInWaitList)] pointer to a list of
1436
+ // /< events that must be complete before the kernel execution.
1437
+ // /< If nullptr, the numEventsInWaitList must be 0, indicating that no wait event.
1438
+ ur_event_handle_t *
1439
+ phEvent // /< [out][optional] return an event object that identifies this particular
1440
+ // /< kernel execution instance. If phEventWaitList and phEvent are not
1441
+ // /< NULL, phEvent must not refer to an element of the phEventWaitList array.
1442
+ ) {
1443
+ auto pfnUSMMemcpy2D = getContext ()->urDdiTable .Enqueue .pfnUSMMemcpy2D ;
1444
+ getContext ()->logger .debug (" ==== pfnUSMMemcpy2D" );
1445
+
1446
+ ur_event_handle_t hEvents[2 ] = {};
1447
+ UR_CALL (pfnUSMMemcpy2D (hQueue, blocking, pDst, dstPitch, pSrc, srcPitch,
1448
+ width, height, numEventsInWaitList, phEventWaitList,
1449
+ &hEvents[0 ]));
1450
+
1451
+ const auto Src = (uptr)pSrc, Dst = (uptr)pDst;
1452
+ auto SrcInfoItOp = getMsanInterceptor ()->findAllocInfoByAddress (Src);
1453
+ auto DstInfoItOp = getMsanInterceptor ()->findAllocInfoByAddress (Dst);
1454
+
1455
+ if (SrcInfoItOp && DstInfoItOp) {
1456
+ auto SrcInfo = (*SrcInfoItOp)->second ;
1457
+ auto DstInfo = (*DstInfoItOp)->second ;
1458
+
1459
+ const auto &DeviceInfo =
1460
+ getMsanInterceptor ()->getDeviceInfo (SrcInfo->Device );
1461
+ const auto SrcShadow = DeviceInfo->Shadow ->MemToShadow (Src);
1462
+ const auto DstShadow = DeviceInfo->Shadow ->MemToShadow (Dst);
1463
+
1464
+ UR_CALL (pfnUSMMemcpy2D (hQueue, blocking, (void *)DstShadow, dstPitch,
1465
+ (void *)SrcShadow, srcPitch, width, height, 0 ,
1466
+ nullptr , &hEvents[1 ]));
1467
+ } else if (DstInfoItOp) {
1468
+ auto DstInfo = (*DstInfoItOp)->second ;
1469
+
1470
+ const auto &DeviceInfo =
1471
+ getMsanInterceptor ()->getDeviceInfo (DstInfo->Device );
1472
+ const auto DstShadow = DeviceInfo->Shadow ->MemToShadow (Dst);
1473
+
1474
+ const char Pattern = 0 ;
1475
+ UR_CALL (getContext ()->urDdiTable .Enqueue .pfnUSMFill2D (
1476
+ hQueue, (void *)DstShadow, dstPitch, 1 , &Pattern, width, height, 0 ,
1477
+ nullptr , &hEvents[1 ]));
1478
+ }
1479
+
1480
+ if (phEvent) {
1481
+ UR_CALL (getContext ()->urDdiTable .Enqueue .pfnEventsWait (
1482
+ hQueue, 2 , hEvents, phEvent));
1483
+ }
1484
+
1485
+ return UR_RESULT_SUCCESS;
1486
+ }
1487
+
1237
1488
// /////////////////////////////////////////////////////////////////////////////
1238
1489
// / @brief Exported function for filling application's Global table
1239
1490
// / with current process' addresses
@@ -1391,6 +1642,10 @@ ur_result_t urGetEnqueueProcAddrTable(
1391
1642
pDdiTable->pfnMemUnmap = ur_sanitizer_layer::msan::urEnqueueMemUnmap;
1392
1643
pDdiTable->pfnKernelLaunch =
1393
1644
ur_sanitizer_layer::msan::urEnqueueKernelLaunch;
1645
+ pDdiTable->pfnUSMFill = ur_sanitizer_layer::msan::urEnqueueUSMFill;
1646
+ pDdiTable->pfnUSMMemcpy = ur_sanitizer_layer::msan::urEnqueueUSMMemcpy;
1647
+ pDdiTable->pfnUSMFill2D = ur_sanitizer_layer::msan::urEnqueueUSMFill2D;
1648
+ pDdiTable->pfnUSMMemcpy2D = ur_sanitizer_layer::msan::urEnqueueUSMMemcpy2D;
1394
1649
1395
1650
return result;
1396
1651
}
@@ -1408,6 +1663,7 @@ ur_result_t urGetUSMProcAddrTable(
1408
1663
ur_result_t result = UR_RESULT_SUCCESS;
1409
1664
1410
1665
pDdiTable->pfnDeviceAlloc = ur_sanitizer_layer::msan::urUSMDeviceAlloc;
1666
+ pDdiTable->pfnFree = ur_sanitizer_layer::msan::urUSMFree;
1411
1667
1412
1668
return result;
1413
1669
}
0 commit comments