@@ -88,6 +88,24 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGet(
88
88
return UR_RESULT_SUCCESS;
89
89
}
90
90
91
+ uint64_t calculateGlobalMemSize (ur_device_handle_t Device) {
92
+ // Cache GlobalMemSize
93
+ Device->ZeGlobalMemSize .Compute =
94
+ [Device](struct ze_global_memsize &GlobalMemSize) {
95
+ for (const auto &ZeDeviceMemoryExtProperty :
96
+ Device->ZeDeviceMemoryProperties ->second ) {
97
+ GlobalMemSize.value += ZeDeviceMemoryExtProperty.physicalSize ;
98
+ }
99
+ if (GlobalMemSize.value == 0 ) {
100
+ for (const auto &ZeDeviceMemoryProperty :
101
+ Device->ZeDeviceMemoryProperties ->first ) {
102
+ GlobalMemSize.value += ZeDeviceMemoryProperty.totalSize ;
103
+ }
104
+ }
105
+ };
106
+ return Device->ZeGlobalMemSize .operator ->()->value ;
107
+ }
108
+
91
109
UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo (
92
110
ur_device_handle_t Device, // /< [in] handle of the device instance
93
111
ur_device_info_t ParamName, // /< [in] type of the info to retrieve
@@ -251,20 +269,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(
251
269
case UR_DEVICE_INFO_MAX_MEM_ALLOC_SIZE:
252
270
return ReturnValue (uint64_t {Device->ZeDeviceProperties ->maxMemAllocSize });
253
271
case UR_DEVICE_INFO_GLOBAL_MEM_SIZE: {
254
- uint64_t GlobalMemSize = 0 ;
255
272
// Support to read physicalSize depends on kernel,
256
273
// so fallback into reading totalSize if physicalSize
257
274
// is not available.
258
- for (const auto &ZeDeviceMemoryExtProperty :
259
- Device->ZeDeviceMemoryProperties ->second ) {
260
- GlobalMemSize += ZeDeviceMemoryExtProperty.physicalSize ;
261
- }
262
- if (GlobalMemSize == 0 ) {
263
- for (const auto &ZeDeviceMemoryProperty :
264
- Device->ZeDeviceMemoryProperties ->first ) {
265
- GlobalMemSize += ZeDeviceMemoryProperty.totalSize ;
266
- }
267
- }
275
+ uint64_t GlobalMemSize = calculateGlobalMemSize (Device);
268
276
return ReturnValue (uint64_t {GlobalMemSize});
269
277
}
270
278
case UR_DEVICE_INFO_LOCAL_MEM_SIZE:
@@ -637,6 +645,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(
637
645
static_cast <int32_t >(ZE_RESULT_ERROR_UNINITIALIZED));
638
646
return UR_RESULT_ERROR_ADAPTER_SPECIFIC;
639
647
}
648
+ // Calculate the global memory size as the max limit that can be reported as
649
+ // "free" memory for the user to allocate.
650
+ uint64_t GlobalMemSize = calculateGlobalMemSize (Device);
640
651
// Only report device memory which zeMemAllocDevice can allocate from.
641
652
// Currently this is only the one enumerated with ordinal 0.
642
653
uint64_t FreeMemory = 0 ;
@@ -661,7 +672,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(
661
672
}
662
673
}
663
674
}
664
- return ReturnValue (FreeMemory);
675
+ return ReturnValue (std::min (GlobalMemSize, FreeMemory) );
665
676
}
666
677
case UR_DEVICE_INFO_MEMORY_CLOCK_RATE: {
667
678
// If there are not any memory modules then return 0.
0 commit comments