Skip to content

Commit 71dd495

Browse files
GeorgeWebkbenzie
authored andcommitted
Update urEnqueueUSMPrefetch entry point to always return a valid event and ignore flags
1 parent b20a877 commit 71dd495

File tree

1 file changed

+43
-33
lines changed

1 file changed

+43
-33
lines changed

source/adapters/hip/enqueue.cpp

Lines changed: 43 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1459,64 +1459,74 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch(
14591459
ur_queue_handle_t hQueue, const void *pMem, size_t size,
14601460
ur_usm_migration_flags_t flags, uint32_t numEventsInWaitList,
14611461
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
1462+
std::ignore = flags;
1463+
14621464
void *HIPDevicePtr = const_cast<void *>(pMem);
14631465
ur_device_handle_t Device = hQueue->getDevice();
14641466

1465-
// If the device does not support managed memory access, we can't set
1466-
// mem_advise.
1467-
if (!getAttribute(Device, hipDeviceAttributeManagedMemory)) {
1468-
setErrorMessage("mem_advise ignored as device does not support "
1469-
" managed memory access",
1470-
UR_RESULT_SUCCESS);
1471-
return UR_RESULT_ERROR_ADAPTER_SPECIFIC;
1472-
}
1473-
1474-
hipPointerAttribute_t attribs;
1475-
// TODO: hipPointerGetAttributes will fail if pMem is non-HIP allocated
1476-
// memory, as it is neither registered as host memory, nor into the address
1477-
// space for the current device, meaning the pMem ptr points to a
1478-
// system-allocated memory. This means we may need to check system-alloacted
1479-
// memory and handle the failure more gracefully.
1480-
UR_CHECK_ERROR(hipPointerGetAttributes(&attribs, pMem));
1481-
// async prefetch requires USM pointer (or hip SVM) to work.
1482-
if (!attribs.isManaged) {
1483-
setErrorMessage("Prefetch hint ignored as prefetch only works with USM",
1484-
UR_RESULT_SUCCESS);
1485-
return UR_RESULT_ERROR_ADAPTER_SPECIFIC;
1486-
}
1487-
1488-
// HIP_POINTER_ATTRIBUTE_RANGE_SIZE is not an attribute in ROCM < 5,
1489-
// so we can't perform this check for such cases.
1467+
// HIP_POINTER_ATTRIBUTE_RANGE_SIZE is not an attribute in ROCM < 5,
1468+
// so we can't perform this check for such cases.
14901469
#if HIP_VERSION_MAJOR >= 5
14911470
unsigned int PointerRangeSize = 0;
14921471
UR_CHECK_ERROR(hipPointerGetAttribute(&PointerRangeSize,
14931472
HIP_POINTER_ATTRIBUTE_RANGE_SIZE,
14941473
(hipDeviceptr_t)HIPDevicePtr));
14951474
UR_ASSERT(size <= PointerRangeSize, UR_RESULT_ERROR_INVALID_SIZE);
14961475
#endif
1497-
// flags is currently unused so fail if set
1498-
if (flags != 0)
1499-
return UR_RESULT_ERROR_INVALID_VALUE;
1476+
15001477
ur_result_t Result = UR_RESULT_SUCCESS;
1501-
std::unique_ptr<ur_event_handle_t_> EventPtr{nullptr};
15021478

15031479
try {
15041480
ScopedContext Active(hQueue->getDevice());
15051481
hipStream_t HIPStream = hQueue->getNextTransferStream();
15061482
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
15071483
phEventWaitList);
1484+
1485+
std::unique_ptr<ur_event_handle_t_> EventPtr{nullptr};
1486+
15081487
if (phEvent) {
15091488
EventPtr =
15101489
std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative(
15111490
UR_COMMAND_USM_PREFETCH, hQueue, HIPStream));
15121491
UR_CHECK_ERROR(EventPtr->start());
15131492
}
1493+
1494+
// Helper to ensure returning a valid event on early exit.
1495+
auto releaseEvent = [&EventPtr, &phEvent]() -> void {
1496+
if (phEvent) {
1497+
UR_CHECK_ERROR(EventPtr->record());
1498+
*phEvent = EventPtr.release();
1499+
}
1500+
};
1501+
1502+
// If the device does not support managed memory access, we can't set
1503+
// mem_advise.
1504+
if (!getAttribute(Device, hipDeviceAttributeManagedMemory)) {
1505+
releaseEvent();
1506+
setErrorMessage("mem_advise ignored as device does not support "
1507+
"managed memory access",
1508+
UR_RESULT_SUCCESS);
1509+
return UR_RESULT_ERROR_ADAPTER_SPECIFIC;
1510+
}
1511+
1512+
hipPointerAttribute_t attribs;
1513+
// TODO: hipPointerGetAttributes will fail if pMem is non-HIP allocated
1514+
// memory, as it is neither registered as host memory, nor into the address
1515+
// space for the current device, meaning the pMem ptr points to a
1516+
// system-allocated memory. This means we may need to check system-alloacted
1517+
// memory and handle the failure more gracefully.
1518+
UR_CHECK_ERROR(hipPointerGetAttributes(&attribs, pMem));
1519+
// async prefetch requires USM pointer (or hip SVM) to work.
1520+
if (!attribs.isManaged) {
1521+
releaseEvent();
1522+
setErrorMessage("Prefetch hint ignored as prefetch only works with USM",
1523+
UR_RESULT_SUCCESS);
1524+
return UR_RESULT_ERROR_ADAPTER_SPECIFIC;
1525+
}
1526+
15141527
UR_CHECK_ERROR(
15151528
hipMemPrefetchAsync(pMem, size, hQueue->getDevice()->get(), HIPStream));
1516-
if (phEvent) {
1517-
UR_CHECK_ERROR(EventPtr->record());
1518-
*phEvent = EventPtr.release();
1519-
}
1529+
releaseEvent();
15201530
} catch (ur_result_t Err) {
15211531
Result = Err;
15221532
}

0 commit comments

Comments
 (0)