Skip to content

Commit 12a67f5

Browse files
authored
Merge pull request intel#1027 from GeorgeWeb/georgi/hip_memadvise
[SYCL][HIP] Implement mem_advise for HIP
2 parents 2b7b827 + c10968f commit 12a67f5

File tree

1 file changed

+196
-43
lines changed

1 file changed

+196
-43
lines changed

source/adapters/hip/enqueue.cpp

Lines changed: 196 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,62 @@ void simpleGuessLocalWorkSize(size_t *ThreadsPerBlock,
8484
--ThreadsPerBlock[0];
8585
}
8686
}
87+
88+
/// Applies the UR USM advice bits set in \p URAdviceFlags to the memory range
/// [\p DevPtr, \p DevPtr + \p Size) by issuing one hipMemAdvise call per
/// recognised bit.
///
/// Device-directed advice is issued against \p Device first; host-directed
/// advice is then issued against hipCpuDeviceId.
///
/// \return UR_RESULT_ERROR_INVALID_ENUMERATION, without issuing any advice, if
///         any of the non-atomic-mostly or bias-cached/uncached bits is set
///         (they have no entry in the tables below); UR_RESULT_SUCCESS
///         otherwise. A failing hipMemAdvise call is reported through
///         UR_CHECK_ERROR.
ur_result_t setHipMemAdvise(const void *DevPtr, const size_t Size,
                            ur_usm_advice_flags_t URAdviceFlags,
                            hipDevice_t Device) {
  // Advice bits that have no hipMemAdvise counterpart in this mapping.
  constexpr ur_usm_advice_flags_t UnmappedAdvice =
      UR_USM_ADVICE_FLAG_SET_NON_ATOMIC_MOSTLY |
      UR_USM_ADVICE_FLAG_CLEAR_NON_ATOMIC_MOSTLY |
      UR_USM_ADVICE_FLAG_BIAS_CACHED | UR_USM_ADVICE_FLAG_BIAS_UNCACHED;
  if ((URAdviceFlags & UnmappedAdvice) != 0) {
    return UR_RESULT_ERROR_INVALID_ENUMERATION;
  }

  using AdviceMapping = std::pair<ur_usm_advice_flags_t, hipMemoryAdvise>;

  // UR advice bit -> HIP advice, to be applied to the queue's device.
  static constexpr std::array<AdviceMapping, 6> DeviceAdviceTable{
      AdviceMapping{UR_USM_ADVICE_FLAG_SET_READ_MOSTLY,
                    hipMemAdviseSetReadMostly},
      AdviceMapping{UR_USM_ADVICE_FLAG_CLEAR_READ_MOSTLY,
                    hipMemAdviseUnsetReadMostly},
      AdviceMapping{UR_USM_ADVICE_FLAG_SET_PREFERRED_LOCATION,
                    hipMemAdviseSetPreferredLocation},
      AdviceMapping{UR_USM_ADVICE_FLAG_CLEAR_PREFERRED_LOCATION,
                    hipMemAdviseUnsetPreferredLocation},
      AdviceMapping{UR_USM_ADVICE_FLAG_SET_ACCESSED_BY_DEVICE,
                    hipMemAdviseSetAccessedBy},
      AdviceMapping{UR_USM_ADVICE_FLAG_CLEAR_ACCESSED_BY_DEVICE,
                    hipMemAdviseUnsetAccessedBy},
  };

  // UR advice bit -> HIP advice, to be applied to the host (hipCpuDeviceId).
  static constexpr std::array<AdviceMapping, 4> HostAdviceTable{
      AdviceMapping{UR_USM_ADVICE_FLAG_SET_PREFERRED_LOCATION_HOST,
                    hipMemAdviseSetPreferredLocation},
      AdviceMapping{UR_USM_ADVICE_FLAG_CLEAR_PREFERRED_LOCATION_HOST,
                    hipMemAdviseUnsetPreferredLocation},
      AdviceMapping{UR_USM_ADVICE_FLAG_SET_ACCESSED_BY_HOST,
                    hipMemAdviseSetAccessedBy},
      AdviceMapping{UR_USM_ADVICE_FLAG_CLEAR_ACCESSED_BY_HOST,
                    hipMemAdviseUnsetAccessedBy},
  };

  // Walks a table in order and forwards every requested advice to Target.
  const auto ApplyAdvice = [&](const auto &Table, auto Target) {
    for (const auto &Entry : Table) {
      if (URAdviceFlags & Entry.first) {
        UR_CHECK_ERROR(hipMemAdvise(DevPtr, Size, Entry.second, Target));
      }
    }
  };

  ApplyAdvice(DeviceAdviceTable, Device);
  ApplyAdvice(HostAdviceTable, hipCpuDeviceId);

  return UR_RESULT_SUCCESS;
}
142+
87143
} // namespace
88144

89145
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWrite(
@@ -1403,87 +1459,184 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch(
14031459
ur_queue_handle_t hQueue, const void *pMem, size_t size,
14041460
ur_usm_migration_flags_t flags, uint32_t numEventsInWaitList,
14051461
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
1462+
std::ignore = flags;
1463+
14061464
void *HIPDevicePtr = const_cast<void *>(pMem);
14071465
ur_device_handle_t Device = hQueue->getDevice();
14081466

1409-
// If the device does not support managed memory access, we can't set
1410-
// mem_advise.
1411-
if (!getAttribute(Device, hipDeviceAttributeManagedMemory)) {
1412-
setErrorMessage("mem_advise ignored as device does not support "
1413-
" managed memory access",
1414-
UR_RESULT_SUCCESS);
1415-
return UR_RESULT_ERROR_ADAPTER_SPECIFIC;
1416-
}
1417-
1418-
hipPointerAttribute_t attribs;
1419-
// TODO: hipPointerGetAttributes will fail if pMem is non-HIP allocated
1420-
// memory, as it is neither registered as host memory, nor into the address
1421-
// space for the current device, meaning the pMem ptr points to a
1422-
// system-allocated memory. This means we may need to check system-alloacted
1423-
// memory and handle the failure more gracefully.
1424-
UR_CHECK_ERROR(hipPointerGetAttributes(&attribs, pMem));
1425-
// async prefetch requires USM pointer (or hip SVM) to work.
1426-
if (!attribs.isManaged) {
1427-
setErrorMessage("Prefetch hint ignored as prefetch only works with USM",
1428-
UR_RESULT_SUCCESS);
1429-
return UR_RESULT_ERROR_ADAPTER_SPECIFIC;
1430-
}
1431-
1432-
// HIP_POINTER_ATTRIBUTE_RANGE_SIZE is not an attribute in ROCM < 5,
1433-
// so we can't perform this check for such cases.
1467+
// HIP_POINTER_ATTRIBUTE_RANGE_SIZE is not an attribute in ROCM < 5,
1468+
// so we can't perform this check for such cases.
14341469
#if HIP_VERSION_MAJOR >= 5
14351470
unsigned int PointerRangeSize = 0;
14361471
UR_CHECK_ERROR(hipPointerGetAttribute(&PointerRangeSize,
14371472
HIP_POINTER_ATTRIBUTE_RANGE_SIZE,
14381473
(hipDeviceptr_t)HIPDevicePtr));
14391474
UR_ASSERT(size <= PointerRangeSize, UR_RESULT_ERROR_INVALID_SIZE);
14401475
#endif
1441-
// flags is currently unused so fail if set
1442-
if (flags != 0)
1443-
return UR_RESULT_ERROR_INVALID_VALUE;
1476+
14441477
ur_result_t Result = UR_RESULT_SUCCESS;
1445-
std::unique_ptr<ur_event_handle_t_> EventPtr{nullptr};
14461478

14471479
try {
14481480
ScopedContext Active(hQueue->getDevice());
14491481
hipStream_t HIPStream = hQueue->getNextTransferStream();
14501482
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
14511483
phEventWaitList);
1484+
1485+
std::unique_ptr<ur_event_handle_t_> EventPtr{nullptr};
1486+
14521487
if (phEvent) {
14531488
EventPtr =
14541489
std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative(
14551490
UR_COMMAND_USM_PREFETCH, hQueue, HIPStream));
14561491
UR_CHECK_ERROR(EventPtr->start());
14571492
}
1493+
1494+
// Helper to ensure returning a valid event on early exit.
1495+
auto releaseEvent = [&EventPtr, &phEvent]() -> void {
1496+
if (phEvent) {
1497+
UR_CHECK_ERROR(EventPtr->record());
1498+
*phEvent = EventPtr.release();
1499+
}
1500+
};
1501+
1502+
// If the device does not support managed memory access, we can't set
1503+
// mem_advise.
1504+
if (!getAttribute(Device, hipDeviceAttributeManagedMemory)) {
1505+
releaseEvent();
1506+
setErrorMessage("mem_advise ignored as device does not support "
1507+
"managed memory access",
1508+
UR_RESULT_SUCCESS);
1509+
return UR_RESULT_ERROR_ADAPTER_SPECIFIC;
1510+
}
1511+
1512+
hipPointerAttribute_t attribs;
1513+
// TODO: hipPointerGetAttributes will fail if pMem is non-HIP allocated
1514+
// memory, as it is neither registered as host memory, nor into the address
1515+
// space for the current device, meaning the pMem ptr points to a
1516+
// system-allocated memory. This means we may need to check system-allocated
1517+
// memory and handle the failure more gracefully.
1518+
UR_CHECK_ERROR(hipPointerGetAttributes(&attribs, pMem));
1519+
// async prefetch requires USM pointer (or hip SVM) to work.
1520+
if (!attribs.isManaged) {
1521+
releaseEvent();
1522+
setErrorMessage("Prefetch hint ignored as prefetch only works with USM",
1523+
UR_RESULT_SUCCESS);
1524+
return UR_RESULT_ERROR_ADAPTER_SPECIFIC;
1525+
}
1526+
14581527
UR_CHECK_ERROR(
14591528
hipMemPrefetchAsync(pMem, size, hQueue->getDevice()->get(), HIPStream));
1460-
if (phEvent) {
1461-
UR_CHECK_ERROR(EventPtr->record());
1462-
*phEvent = EventPtr.release();
1463-
}
1529+
releaseEvent();
14641530
} catch (ur_result_t Err) {
14651531
Result = Err;
14661532
}
14671533

14681534
return Result;
14691535
}
14701536

1537+
/// USM: memadvise API to govern behavior of automatic migration mechanisms
14711538
UR_APIEXPORT ur_result_t UR_APICALL
14721539
urEnqueueUSMAdvise(ur_queue_handle_t hQueue, const void *pMem, size_t size,
1473-
ur_usm_advice_flags_t, ur_event_handle_t *phEvent) {
1540+
ur_usm_advice_flags_t advice, ur_event_handle_t *phEvent) {
1541+
UR_ASSERT(pMem && size > 0, UR_RESULT_ERROR_INVALID_VALUE);
14741542
void *HIPDevicePtr = const_cast<void *>(pMem);
1475-
// HIP_POINTER_ATTRIBUTE_RANGE_SIZE is not an attribute in ROCM < 5,
1476-
// so we can't perform this check for such cases.
1543+
ur_device_handle_t Device = hQueue->getDevice();
1544+
14771545
#if HIP_VERSION_MAJOR >= 5
1478-
unsigned int PointerRangeSize = 0;
1479-
UR_CHECK_ERROR(hipPointerGetAttribute(&PointerRangeSize,
1480-
HIP_POINTER_ATTRIBUTE_RANGE_SIZE,
1481-
(hipDeviceptr_t)HIPDevicePtr));
1546+
// NOTE: The hipPointerGetAttribute API is marked as beta, meaning, while this
1547+
// is feature complete, it is still open to changes and outstanding issues.
1548+
size_t PointerRangeSize = 0;
1549+
UR_CHECK_ERROR(hipPointerGetAttribute(
1550+
&PointerRangeSize, HIP_POINTER_ATTRIBUTE_RANGE_SIZE,
1551+
static_cast<hipDeviceptr_t>(HIPDevicePtr)));
14821552
UR_ASSERT(size <= PointerRangeSize, UR_RESULT_ERROR_INVALID_SIZE);
14831553
#endif
1484-
// TODO implement a mapping to hipMemAdvise once the expected behaviour
1485-
// of urEnqueueUSMAdvise is detailed in the USM extension
1486-
return urEnqueueEventsWait(hQueue, 0, nullptr, phEvent);
1554+
1555+
ur_result_t Result = UR_RESULT_SUCCESS;
1556+
1557+
try {
1558+
ScopedContext Active(Device);
1559+
std::unique_ptr<ur_event_handle_t_> EventPtr{nullptr};
1560+
1561+
if (phEvent) {
1562+
EventPtr =
1563+
std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative(
1564+
UR_COMMAND_USM_ADVISE, hQueue, hQueue->getNextTransferStream()));
1565+
EventPtr->start();
1566+
}
1567+
1568+
// Helper to ensure returning a valid event on early exit.
1569+
auto releaseEvent = [&EventPtr, &phEvent]() -> void {
1570+
if (phEvent) {
1571+
UR_CHECK_ERROR(EventPtr->record());
1572+
*phEvent = EventPtr.release();
1573+
}
1574+
};
1575+
1576+
// If the device does not support managed memory access, we can't set
1577+
// mem_advise.
1578+
if (!getAttribute(Device, hipDeviceAttributeManagedMemory)) {
1579+
releaseEvent();
1580+
setErrorMessage("mem_advise ignored as device does not support "
1581+
"managed memory access",
1582+
UR_RESULT_SUCCESS);
1583+
return UR_RESULT_ERROR_ADAPTER_SPECIFIC;
1584+
}
1585+
1586+
// Passing MEM_ADVICE_SET/MEM_ADVICE_CLEAR_PREFERRED_LOCATION to
1587+
// hipMemAdvise on a GPU device requires the GPU device to report a non-zero
1588+
// value for hipDeviceAttributeConcurrentManagedAccess. Therefore, ignore
1589+
// the mem advice if concurrent managed memory access is not available.
1590+
if (advice & (UR_USM_ADVICE_FLAG_SET_PREFERRED_LOCATION |
1591+
UR_USM_ADVICE_FLAG_CLEAR_PREFERRED_LOCATION |
1592+
UR_USM_ADVICE_FLAG_SET_ACCESSED_BY_DEVICE |
1593+
UR_USM_ADVICE_FLAG_CLEAR_ACCESSED_BY_DEVICE |
1594+
UR_USM_ADVICE_FLAG_DEFAULT)) {
1595+
if (!getAttribute(Device, hipDeviceAttributeConcurrentManagedAccess)) {
1596+
releaseEvent();
1597+
setErrorMessage("mem_advise ignored as device does not support "
1598+
"concurrent managed access",
1599+
UR_RESULT_SUCCESS);
1600+
return UR_RESULT_ERROR_ADAPTER_SPECIFIC;
1601+
}
1602+
1603+
// TODO: If pMem points to valid system-allocated pageable memory, we
1604+
// should check that the device also has the
1605+
// hipDeviceAttributePageableMemoryAccess property, so that a valid
1606+
// read-only copy can be created on the device. This also applies for
1607+
// UR_USM_MEM_ADVICE_SET/MEM_ADVICE_CLEAR_READ_MOSTLY.
1608+
}
1609+
1610+
const auto DeviceID = Device->get();
1611+
if (advice & UR_USM_ADVICE_FLAG_DEFAULT) {
1612+
UR_CHECK_ERROR(
1613+
hipMemAdvise(pMem, size, hipMemAdviseUnsetReadMostly, DeviceID));
1614+
UR_CHECK_ERROR(hipMemAdvise(
1615+
pMem, size, hipMemAdviseUnsetPreferredLocation, DeviceID));
1616+
UR_CHECK_ERROR(
1617+
hipMemAdvise(pMem, size, hipMemAdviseUnsetAccessedBy, DeviceID));
1618+
} else {
1619+
Result = setHipMemAdvise(HIPDevicePtr, size, advice, DeviceID);
1620+
// UR_RESULT_ERROR_INVALID_ENUMERATION is returned when using a valid but
1621+
// currently unmapped advice arguments as not supported by this platform.
1622+
// Therefore, warn the user instead of throwing and aborting the runtime.
1623+
if (Result == UR_RESULT_ERROR_INVALID_ENUMERATION) {
1624+
releaseEvent();
1625+
setErrorMessage("mem_advise is ignored as the advice argument is not "
1626+
"supported by this device",
1627+
UR_RESULT_SUCCESS);
1628+
return UR_RESULT_ERROR_ADAPTER_SPECIFIC;
1629+
}
1630+
}
1631+
1632+
releaseEvent();
1633+
} catch (ur_result_t err) {
1634+
Result = err;
1635+
} catch (...) {
1636+
Result = UR_RESULT_ERROR_UNKNOWN;
1637+
}
1638+
1639+
return Result;
14871640
}
14881641

14891642
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill2D(

0 commit comments

Comments
 (0)