@@ -82,6 +82,61 @@ void simpleGuessLocalWorkSize(size_t *ThreadsPerBlock,
     --ThreadsPerBlock[0];
   }
 }
+
+ur_result_t setHipMemAdvise(const void *DevPtr, size_t Size,
+                            ur_usm_advice_flags_t URAdviceFlags,
+                            hipDevice_t Device) {
+  std::unordered_map<ur_usm_advice_flags_t, hipMemoryAdvise>
+      URToHIPMemAdviseDeviceFlagsMap = {
+          {UR_USM_ADVICE_FLAG_SET_READ_MOSTLY, hipMemAdviseSetReadMostly},
+          {UR_USM_ADVICE_FLAG_CLEAR_READ_MOSTLY, hipMemAdviseUnsetReadMostly},
+          {UR_USM_ADVICE_FLAG_SET_PREFERRED_LOCATION,
+           hipMemAdviseSetPreferredLocation},
+          {UR_USM_ADVICE_FLAG_CLEAR_PREFERRED_LOCATION,
+           hipMemAdviseUnsetPreferredLocation},
+          {UR_USM_ADVICE_FLAG_SET_ACCESSED_BY_DEVICE,
+           hipMemAdviseSetAccessedBy},
+          {UR_USM_ADVICE_FLAG_CLEAR_ACCESSED_BY_DEVICE,
+           hipMemAdviseUnsetAccessedBy},
+      };
+  for (auto &FlagPair : URToHIPMemAdviseDeviceFlagsMap) {
+    if (URAdviceFlags & FlagPair.first) {
+      UR_CHECK_ERROR(hipMemAdvise(DevPtr, Size, FlagPair.second, Device));
+    }
+  }
+
+  static std::unordered_map<ur_usm_advice_flags_t, hipMemoryAdvise>
+      URToHIPMemAdviseHostFlagsMap = {
+          {UR_USM_ADVICE_FLAG_SET_PREFERRED_LOCATION_HOST,
+           hipMemAdviseSetPreferredLocation},
+          {UR_USM_ADVICE_FLAG_CLEAR_PREFERRED_LOCATION_HOST,
+           hipMemAdviseUnsetPreferredLocation},
+          {UR_USM_ADVICE_FLAG_SET_ACCESSED_BY_HOST, hipMemAdviseSetAccessedBy},
+          {UR_USM_ADVICE_FLAG_CLEAR_ACCESSED_BY_HOST,
+           hipMemAdviseUnsetAccessedBy},
+      };
+
+  for (auto &FlagPair : URToHIPMemAdviseHostFlagsMap) {
+    if (URAdviceFlags & FlagPair.first) {
+      UR_CHECK_ERROR(
+          hipMemAdvise(DevPtr, Size, FlagPair.second, hipCpuDeviceId));
+    }
+  }
+
+  static constexpr std::array<ur_usm_advice_flags_t, 4> UnmappedMemAdviseFlags =
+      {UR_USM_ADVICE_FLAG_SET_NON_ATOMIC_MOSTLY,
+       UR_USM_ADVICE_FLAG_CLEAR_NON_ATOMIC_MOSTLY,
+       UR_USM_ADVICE_FLAG_BIAS_CACHED, UR_USM_ADVICE_FLAG_BIAS_UNCACHED};
+
+  for (auto &UnmappedFlag : UnmappedMemAdviseFlags) {
+    if (URAdviceFlags & UnmappedFlag) {
+      return UR_RESULT_ERROR_INVALID_ENUMERATION;
+    }
+  }
+
+  return UR_RESULT_SUCCESS;
+}
+
 } // namespace
 
 UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWrite(
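For context, a minimal sketch of how the new helper is intended to be driven. This is hypothetical code, not part of the patch; Ptr, Bytes and Dev are placeholders for a managed USM allocation, its size and the target hipDevice_t.

// Hypothetical sketch (not part of the patch): every bit set in the UR advice
// mask is translated by setHipMemAdvise into one hipMemAdvise call, device
// hints targeting Dev and *_HOST hints targeting hipCpuDeviceId.
ur_result_t adviseReadMostly(void *Ptr, size_t Bytes, hipDevice_t Dev) {
  ur_usm_advice_flags_t Advice = UR_USM_ADVICE_FLAG_SET_READ_MOSTLY |
                                 UR_USM_ADVICE_FLAG_SET_ACCESSED_BY_HOST;
  // Valid but unmapped hints (e.g. UR_USM_ADVICE_FLAG_BIAS_CACHED) would make
  // the helper return UR_RESULT_ERROR_INVALID_ENUMERATION without calling HIP.
  return setHipMemAdvise(Ptr, Bytes, Advice, Dev);
}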
@@ -1328,23 +1383,87 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch(
 #endif
 }
 
+/// USM: memadvise API to govern behavior of automatic migration mechanisms
 UR_APIEXPORT ur_result_t UR_APICALL
 urEnqueueUSMAdvise(ur_queue_handle_t hQueue, const void *pMem, size_t size,
-                   ur_usm_advice_flags_t, ur_event_handle_t *phEvent) {
-#if HIP_VERSION_MAJOR >= 5
+                   ur_usm_advice_flags_t advice, ur_event_handle_t *phEvent) {
+  UR_ASSERT(pMem && size > 0, UR_RESULT_ERROR_INVALID_VALUE);
   void *HIPDevicePtr = const_cast<void *>(pMem);
+
+  // Passing MEM_ADVISE_SET/MEM_ADVISE_CLEAR_PREFERRED_LOCATION to hipMemAdvise
+  // on a GPU device requires the GPU device to report a non-zero value for
+  // hipDeviceAttributeConcurrentManagedAccess. Therefore, ignore the memory
+  // advice if concurrent managed memory access is not available.
+  if ((advice & UR_USM_ADVICE_FLAG_SET_PREFERRED_LOCATION) ||
+      (advice & UR_USM_ADVICE_FLAG_CLEAR_PREFERRED_LOCATION) ||
+      (advice & UR_USM_ADVICE_FLAG_SET_ACCESSED_BY_DEVICE) ||
+      (advice & UR_USM_ADVICE_FLAG_CLEAR_ACCESSED_BY_DEVICE) ||
+      (advice & UR_USM_ADVICE_FLAG_DEFAULT)) {
+    ur_device_handle_t Device = hQueue->getContext()->getDevice();
+    if (!getAttribute(Device, hipDeviceAttributeConcurrentManagedAccess)) {
+      setErrorMessage("mem_advise ignored as device does not support "
+                      "concurrent managed access",
+                      UR_RESULT_SUCCESS);
+      return UR_RESULT_ERROR_ADAPTER_SPECIFIC;
+    }
+
+    // TODO: If pMem points to valid system-allocated pageable memory, we should
+    // check that the device also has the hipDeviceAttributePageableMemoryAccess
+    // property.
+  }
+
   unsigned int PointerRangeSize = 0;
-  UR_CHECK_ERROR(hipPointerGetAttribute(&PointerRangeSize,
-                                        HIP_POINTER_ATTRIBUTE_RANGE_SIZE,
-                                        (hipDeviceptr_t)HIPDevicePtr));
+  UR_CHECK_ERROR(hipPointerGetAttribute(
+      &PointerRangeSize, HIP_POINTER_ATTRIBUTE_RANGE_SIZE,
+      static_cast<hipDeviceptr_t>(HIPDevicePtr)));
   UR_ASSERT(size <= PointerRangeSize, UR_RESULT_ERROR_INVALID_SIZE);
 
-  // TODO implement a mapping to hipMemAdvise once the expected behaviour
-  // of urEnqueueUSMAdvise is detailed in the USM extension
-  return urEnqueueEventsWait(hQueue, 0, nullptr, phEvent);
-#else
-  return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
-#endif
+  ur_result_t Result = UR_RESULT_SUCCESS;
+  std::unique_ptr<ur_event_handle_t_> EventPtr{nullptr};
+
+  try {
+    ScopedContext Active(hQueue->getDevice());
+
+    if (phEvent) {
+      EventPtr =
+          std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative(
+              UR_COMMAND_USM_ADVISE, hQueue, hQueue->getNextTransferStream()));
+      EventPtr->start();
+    }
+
+    if (advice & UR_USM_ADVICE_FLAG_DEFAULT) {
+      UR_CHECK_ERROR(hipMemAdvise(pMem, size, hipMemAdviseUnsetReadMostly,
+                                  hQueue->getContext()->getDevice()->get()));
+      UR_CHECK_ERROR(hipMemAdvise(pMem, size,
+                                  hipMemAdviseUnsetPreferredLocation,
+                                  hQueue->getContext()->getDevice()->get()));
+      UR_CHECK_ERROR(hipMemAdvise(pMem, size, hipMemAdviseUnsetAccessedBy,
+                                  hQueue->getContext()->getDevice()->get()));
+    } else {
+      Result = setHipMemAdvise(HIPDevicePtr, size, advice,
+                               hQueue->getContext()->getDevice()->get());
+      // UR_RESULT_ERROR_INVALID_ENUMERATION is returned when the advice value
+      // is valid but currently unmapped, i.e. not supported by this platform.
+      // Therefore, warn the user instead of throwing and aborting the runtime.
+      if (Result == UR_RESULT_ERROR_INVALID_ENUMERATION) {
+        setErrorMessage("mem_advise is ignored as the advice argument is not "
+                        "supported by this device.",
+                        Result);
+        return UR_RESULT_ERROR_ADAPTER_SPECIFIC;
+      }
+    }
+
+    if (phEvent) {
+      Result = EventPtr->record();
+      *phEvent = EventPtr.release();
+    }
+  } catch (ur_result_t err) {
+    Result = err;
+  } catch (...) {
+    Result = UR_RESULT_ERROR_UNKNOWN;
+  }
+
+  return Result;
 }
 
 UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill2D(
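As a rough usage sketch of the entry point implemented above: hypothetical caller code, not part of the patch; Queue, Ptr and Bytes stand for an existing queue and a USM allocation.

// Hypothetical caller sketch: valid-but-unsupported hints now surface as
// UR_RESULT_ERROR_ADAPTER_SPECIFIC with a warning message rather than failing,
// so a caller can treat that status as "advice ignored" and continue.
ur_result_t adviseThenContinue(ur_queue_handle_t Queue, void *Ptr,
                               size_t Bytes) {
  ur_event_handle_t Event = nullptr;
  ur_result_t Res = urEnqueueUSMAdvise(
      Queue, Ptr, Bytes, UR_USM_ADVICE_FLAG_SET_PREFERRED_LOCATION, &Event);
  if (Res == UR_RESULT_ERROR_ADAPTER_SPECIFIC) {
    Res = UR_RESULT_SUCCESS; // the hint was ignored; execution is unaffected
  }
  return Res;
}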