@@ -84,6 +84,62 @@ void simpleGuessLocalWorkSize(size_t *ThreadsPerBlock,
84
84
--ThreadsPerBlock[0 ];
85
85
}
86
86
}
87
+
88
+ ur_result_t setHipMemAdvise (const void *DevPtr, const size_t Size,
89
+ ur_usm_advice_flags_t URAdviceFlags,
90
+ hipDevice_t Device) {
91
+ // Handle unmapped memory advice flags
92
+ if (URAdviceFlags &
93
+ (UR_USM_ADVICE_FLAG_SET_NON_ATOMIC_MOSTLY |
94
+ UR_USM_ADVICE_FLAG_CLEAR_NON_ATOMIC_MOSTLY |
95
+ UR_USM_ADVICE_FLAG_BIAS_CACHED | UR_USM_ADVICE_FLAG_BIAS_UNCACHED)) {
96
+ return UR_RESULT_ERROR_INVALID_ENUMERATION;
97
+ }
98
+
99
+ using ur_to_hip_advice_t = std::pair<ur_usm_advice_flags_t , hipMemoryAdvise>;
100
+
101
+ static constexpr std::array<ur_to_hip_advice_t , 6 >
102
+ URToHIPMemAdviseDeviceFlags{
103
+ std::make_pair (UR_USM_ADVICE_FLAG_SET_READ_MOSTLY,
104
+ hipMemAdviseSetReadMostly),
105
+ std::make_pair (UR_USM_ADVICE_FLAG_CLEAR_READ_MOSTLY,
106
+ hipMemAdviseUnsetReadMostly),
107
+ std::make_pair (UR_USM_ADVICE_FLAG_SET_PREFERRED_LOCATION,
108
+ hipMemAdviseSetPreferredLocation),
109
+ std::make_pair (UR_USM_ADVICE_FLAG_CLEAR_PREFERRED_LOCATION,
110
+ hipMemAdviseUnsetPreferredLocation),
111
+ std::make_pair (UR_USM_ADVICE_FLAG_SET_ACCESSED_BY_DEVICE,
112
+ hipMemAdviseSetAccessedBy),
113
+ std::make_pair (UR_USM_ADVICE_FLAG_CLEAR_ACCESSED_BY_DEVICE,
114
+ hipMemAdviseUnsetAccessedBy),
115
+ };
116
+ for (auto &FlagPair : URToHIPMemAdviseDeviceFlags) {
117
+ if (URAdviceFlags & FlagPair.first ) {
118
+ UR_CHECK_ERROR (hipMemAdvise (DevPtr, Size, FlagPair.second , Device));
119
+ }
120
+ }
121
+
122
+ static constexpr std::array<ur_to_hip_advice_t , 4 > URToHIPMemAdviseHostFlags{
123
+ std::make_pair (UR_USM_ADVICE_FLAG_SET_PREFERRED_LOCATION_HOST,
124
+ hipMemAdviseSetPreferredLocation),
125
+ std::make_pair (UR_USM_ADVICE_FLAG_CLEAR_PREFERRED_LOCATION_HOST,
126
+ hipMemAdviseUnsetPreferredLocation),
127
+ std::make_pair (UR_USM_ADVICE_FLAG_SET_ACCESSED_BY_HOST,
128
+ hipMemAdviseSetAccessedBy),
129
+ std::make_pair (UR_USM_ADVICE_FLAG_CLEAR_ACCESSED_BY_HOST,
130
+ hipMemAdviseUnsetAccessedBy),
131
+ };
132
+
133
+ for (auto &FlagPair : URToHIPMemAdviseHostFlags) {
134
+ if (URAdviceFlags & FlagPair.first ) {
135
+ UR_CHECK_ERROR (
136
+ hipMemAdvise (DevPtr, Size, FlagPair.second , hipCpuDeviceId));
137
+ }
138
+ }
139
+
140
+ return UR_RESULT_SUCCESS;
141
+ }
142
+
87
143
} // namespace
88
144
89
145
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWrite (
@@ -1403,87 +1459,184 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch(
1403
1459
ur_queue_handle_t hQueue, const void *pMem, size_t size,
1404
1460
ur_usm_migration_flags_t flags, uint32_t numEventsInWaitList,
1405
1461
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
1462
+ std::ignore = flags;
1463
+
1406
1464
void *HIPDevicePtr = const_cast <void *>(pMem);
1407
1465
ur_device_handle_t Device = hQueue->getDevice ();
1408
1466
1409
- // If the device does not support managed memory access, we can't set
1410
- // mem_advise.
1411
- if (!getAttribute (Device, hipDeviceAttributeManagedMemory)) {
1412
- setErrorMessage (" mem_advise ignored as device does not support "
1413
- " managed memory access" ,
1414
- UR_RESULT_SUCCESS);
1415
- return UR_RESULT_ERROR_ADAPTER_SPECIFIC;
1416
- }
1417
-
1418
- hipPointerAttribute_t attribs;
1419
- // TODO: hipPointerGetAttributes will fail if pMem is non-HIP allocated
1420
- // memory, as it is neither registered as host memory, nor into the address
1421
- // space for the current device, meaning the pMem ptr points to a
1422
- // system-allocated memory. This means we may need to check system-alloacted
1423
- // memory and handle the failure more gracefully.
1424
- UR_CHECK_ERROR (hipPointerGetAttributes (&attribs, pMem));
1425
- // async prefetch requires USM pointer (or hip SVM) to work.
1426
- if (!attribs.isManaged ) {
1427
- setErrorMessage (" Prefetch hint ignored as prefetch only works with USM" ,
1428
- UR_RESULT_SUCCESS);
1429
- return UR_RESULT_ERROR_ADAPTER_SPECIFIC;
1430
- }
1431
-
1432
- // HIP_POINTER_ATTRIBUTE_RANGE_SIZE is not an attribute in ROCM < 5,
1433
- // so we can't perform this check for such cases.
1467
+ // HIP_POINTER_ATTRIBUTE_RANGE_SIZE is not an attribute in ROCM < 5,
1468
+ // so we can't perform this check for such cases.
1434
1469
#if HIP_VERSION_MAJOR >= 5
1435
1470
unsigned int PointerRangeSize = 0 ;
1436
1471
UR_CHECK_ERROR (hipPointerGetAttribute (&PointerRangeSize,
1437
1472
HIP_POINTER_ATTRIBUTE_RANGE_SIZE,
1438
1473
(hipDeviceptr_t)HIPDevicePtr));
1439
1474
UR_ASSERT (size <= PointerRangeSize, UR_RESULT_ERROR_INVALID_SIZE);
1440
1475
#endif
1441
- // flags is currently unused so fail if set
1442
- if (flags != 0 )
1443
- return UR_RESULT_ERROR_INVALID_VALUE;
1476
+
1444
1477
ur_result_t Result = UR_RESULT_SUCCESS;
1445
- std::unique_ptr<ur_event_handle_t_> EventPtr{nullptr };
1446
1478
1447
1479
try {
1448
1480
ScopedContext Active (hQueue->getDevice ());
1449
1481
hipStream_t HIPStream = hQueue->getNextTransferStream ();
1450
1482
Result = enqueueEventsWait (hQueue, HIPStream, numEventsInWaitList,
1451
1483
phEventWaitList);
1484
+
1485
+ std::unique_ptr<ur_event_handle_t_> EventPtr{nullptr };
1486
+
1452
1487
if (phEvent) {
1453
1488
EventPtr =
1454
1489
std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative (
1455
1490
UR_COMMAND_USM_PREFETCH, hQueue, HIPStream));
1456
1491
UR_CHECK_ERROR (EventPtr->start ());
1457
1492
}
1493
+
1494
+ // Helper to ensure returning a valid event on early exit.
1495
+ auto releaseEvent = [&EventPtr, &phEvent]() -> void {
1496
+ if (phEvent) {
1497
+ UR_CHECK_ERROR (EventPtr->record ());
1498
+ *phEvent = EventPtr.release ();
1499
+ }
1500
+ };
1501
+
1502
+ // If the device does not support managed memory access, we can't set
1503
+ // mem_advise.
1504
+ if (!getAttribute (Device, hipDeviceAttributeManagedMemory)) {
1505
+ releaseEvent ();
1506
+ setErrorMessage (" mem_advise ignored as device does not support "
1507
+ " managed memory access" ,
1508
+ UR_RESULT_SUCCESS);
1509
+ return UR_RESULT_ERROR_ADAPTER_SPECIFIC;
1510
+ }
1511
+
1512
+ hipPointerAttribute_t attribs;
1513
+ // TODO: hipPointerGetAttributes will fail if pMem is non-HIP allocated
1514
+ // memory, as it is neither registered as host memory, nor into the address
1515
+ // space for the current device, meaning the pMem ptr points to a
1516
+ // system-allocated memory. This means we may need to check system-alloacted
1517
+ // memory and handle the failure more gracefully.
1518
+ UR_CHECK_ERROR (hipPointerGetAttributes (&attribs, pMem));
1519
+ // async prefetch requires USM pointer (or hip SVM) to work.
1520
+ if (!attribs.isManaged ) {
1521
+ releaseEvent ();
1522
+ setErrorMessage (" Prefetch hint ignored as prefetch only works with USM" ,
1523
+ UR_RESULT_SUCCESS);
1524
+ return UR_RESULT_ERROR_ADAPTER_SPECIFIC;
1525
+ }
1526
+
1458
1527
UR_CHECK_ERROR (
1459
1528
hipMemPrefetchAsync (pMem, size, hQueue->getDevice ()->get (), HIPStream));
1460
- if (phEvent) {
1461
- UR_CHECK_ERROR (EventPtr->record ());
1462
- *phEvent = EventPtr.release ();
1463
- }
1529
+ releaseEvent ();
1464
1530
} catch (ur_result_t Err) {
1465
1531
Result = Err;
1466
1532
}
1467
1533
1468
1534
return Result;
1469
1535
}
1470
1536
1537
+ // / USM: memadvise API to govern behavior of automatic migration mechanisms
1471
1538
UR_APIEXPORT ur_result_t UR_APICALL
1472
1539
urEnqueueUSMAdvise (ur_queue_handle_t hQueue, const void *pMem, size_t size,
1473
- ur_usm_advice_flags_t , ur_event_handle_t *phEvent) {
1540
+ ur_usm_advice_flags_t advice, ur_event_handle_t *phEvent) {
1541
+ UR_ASSERT (pMem && size > 0 , UR_RESULT_ERROR_INVALID_VALUE);
1474
1542
void *HIPDevicePtr = const_cast <void *>(pMem);
1475
- // HIP_POINTER_ATTRIBUTE_RANGE_SIZE is not an attribute in ROCM < 5,
1476
- // so we can't perform this check for such cases.
1543
+ ur_device_handle_t Device = hQueue-> getDevice ();
1544
+
1477
1545
#if HIP_VERSION_MAJOR >= 5
1478
- unsigned int PointerRangeSize = 0 ;
1479
- UR_CHECK_ERROR (hipPointerGetAttribute (&PointerRangeSize,
1480
- HIP_POINTER_ATTRIBUTE_RANGE_SIZE,
1481
- (hipDeviceptr_t)HIPDevicePtr));
1546
+ // NOTE: The hipPointerGetAttribute API is marked as beta, meaning, while this
1547
+ // is feature complete, it is still open to changes and outstanding issues.
1548
+ size_t PointerRangeSize = 0 ;
1549
+ UR_CHECK_ERROR (hipPointerGetAttribute (
1550
+ &PointerRangeSize, HIP_POINTER_ATTRIBUTE_RANGE_SIZE,
1551
+ static_cast <hipDeviceptr_t>(HIPDevicePtr)));
1482
1552
UR_ASSERT (size <= PointerRangeSize, UR_RESULT_ERROR_INVALID_SIZE);
1483
1553
#endif
1484
- // TODO implement a mapping to hipMemAdvise once the expected behaviour
1485
- // of urEnqueueUSMAdvise is detailed in the USM extension
1486
- return urEnqueueEventsWait (hQueue, 0 , nullptr , phEvent);
1554
+
1555
+ ur_result_t Result = UR_RESULT_SUCCESS;
1556
+
1557
+ try {
1558
+ ScopedContext Active (Device);
1559
+ std::unique_ptr<ur_event_handle_t_> EventPtr{nullptr };
1560
+
1561
+ if (phEvent) {
1562
+ EventPtr =
1563
+ std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative (
1564
+ UR_COMMAND_USM_ADVISE, hQueue, hQueue->getNextTransferStream ()));
1565
+ EventPtr->start ();
1566
+ }
1567
+
1568
+ // Helper to ensure returning a valid event on early exit.
1569
+ auto releaseEvent = [&EventPtr, &phEvent]() -> void {
1570
+ if (phEvent) {
1571
+ UR_CHECK_ERROR (EventPtr->record ());
1572
+ *phEvent = EventPtr.release ();
1573
+ }
1574
+ };
1575
+
1576
+ // If the device does not support managed memory access, we can't set
1577
+ // mem_advise.
1578
+ if (!getAttribute (Device, hipDeviceAttributeManagedMemory)) {
1579
+ releaseEvent ();
1580
+ setErrorMessage (" mem_advise ignored as device does not support "
1581
+ " managed memory access" ,
1582
+ UR_RESULT_SUCCESS);
1583
+ return UR_RESULT_ERROR_ADAPTER_SPECIFIC;
1584
+ }
1585
+
1586
+ // Passing MEM_ADVICE_SET/MEM_ADVICE_CLEAR_PREFERRED_LOCATION to
1587
+ // hipMemAdvise on a GPU device requires the GPU device to report a non-zero
1588
+ // value for hipDeviceAttributeConcurrentManagedAccess. Therefore, ignore
1589
+ // the mem advice if concurrent managed memory access is not available.
1590
+ if (advice & (UR_USM_ADVICE_FLAG_SET_PREFERRED_LOCATION |
1591
+ UR_USM_ADVICE_FLAG_CLEAR_PREFERRED_LOCATION |
1592
+ UR_USM_ADVICE_FLAG_SET_ACCESSED_BY_DEVICE |
1593
+ UR_USM_ADVICE_FLAG_CLEAR_ACCESSED_BY_DEVICE |
1594
+ UR_USM_ADVICE_FLAG_DEFAULT)) {
1595
+ if (!getAttribute (Device, hipDeviceAttributeConcurrentManagedAccess)) {
1596
+ releaseEvent ();
1597
+ setErrorMessage (" mem_advise ignored as device does not support "
1598
+ " concurrent managed access" ,
1599
+ UR_RESULT_SUCCESS);
1600
+ return UR_RESULT_ERROR_ADAPTER_SPECIFIC;
1601
+ }
1602
+
1603
+ // TODO: If pMem points to valid system-allocated pageable memory, we
1604
+ // should check that the device also has the
1605
+ // hipDeviceAttributePageableMemoryAccess property, so that a valid
1606
+ // read-only copy can be created on the device. This also applies for
1607
+ // UR_USM_MEM_ADVICE_SET/MEM_ADVICE_CLEAR_READ_MOSTLY.
1608
+ }
1609
+
1610
+ const auto DeviceID = Device->get ();
1611
+ if (advice & UR_USM_ADVICE_FLAG_DEFAULT) {
1612
+ UR_CHECK_ERROR (
1613
+ hipMemAdvise (pMem, size, hipMemAdviseUnsetReadMostly, DeviceID));
1614
+ UR_CHECK_ERROR (hipMemAdvise (
1615
+ pMem, size, hipMemAdviseUnsetPreferredLocation, DeviceID));
1616
+ UR_CHECK_ERROR (
1617
+ hipMemAdvise (pMem, size, hipMemAdviseUnsetAccessedBy, DeviceID));
1618
+ } else {
1619
+ Result = setHipMemAdvise (HIPDevicePtr, size, advice, DeviceID);
1620
+ // UR_RESULT_ERROR_INVALID_ENUMERATION is returned when using a valid but
1621
+ // currently unmapped advice arguments as not supported by this platform.
1622
+ // Therefore, warn the user instead of throwing and aborting the runtime.
1623
+ if (Result == UR_RESULT_ERROR_INVALID_ENUMERATION) {
1624
+ releaseEvent ();
1625
+ setErrorMessage (" mem_advise is ignored as the advice argument is not "
1626
+ " supported by this device" ,
1627
+ UR_RESULT_SUCCESS);
1628
+ return UR_RESULT_ERROR_ADAPTER_SPECIFIC;
1629
+ }
1630
+ }
1631
+
1632
+ releaseEvent ();
1633
+ } catch (ur_result_t err) {
1634
+ Result = err;
1635
+ } catch (...) {
1636
+ Result = UR_RESULT_ERROR_UNKNOWN;
1637
+ }
1638
+
1639
+ return Result;
1487
1640
}
1488
1641
1489
1642
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill2D (
0 commit comments