11
11
#ifndef USM_POOL_MANAGER_HPP
12
12
#define USM_POOL_MANAGER_HPP 1
13
13
14
+ #include < ur_ddi.h>
15
+
14
16
#include " logger/ur_logger.hpp"
15
17
#include " umf_helpers.hpp"
16
18
#include " ur_api.h"
26
28
27
29
namespace usm {
28
30
31
+ namespace detail {
32
+ struct ddiTables {
33
+ ddiTables () {
34
+ auto ret =
35
+ urGetDeviceProcAddrTable (UR_API_VERSION_CURRENT, &deviceDdiTable);
36
+ if (ret != UR_RESULT_SUCCESS) {
37
+ throw ret;
38
+ }
39
+
40
+ ret =
41
+ urGetContextProcAddrTable (UR_API_VERSION_CURRENT, &contextDdiTable);
42
+ if (ret != UR_RESULT_SUCCESS) {
43
+ throw ret;
44
+ }
45
+ }
46
+ ur_device_dditable_t deviceDdiTable;
47
+ ur_context_dditable_t contextDdiTable;
48
+ };
49
+ } // namespace detail
50
+
29
51
// / @brief describes an internal USM pool instance.
30
52
struct pool_descriptor {
31
53
ur_usm_pool_handle_t poolHandle;
@@ -44,9 +66,12 @@ struct pool_descriptor {
44
66
45
67
static inline std::pair<ur_result_t , std::vector<ur_device_handle_t >>
46
68
urGetSubDevices (ur_device_handle_t hDevice) {
69
+ static detail::ddiTables ddi;
70
+
47
71
uint32_t nComputeUnits;
48
- auto ret = urDeviceGetInfo (hDevice, UR_DEVICE_INFO_MAX_COMPUTE_UNITS,
49
- sizeof (nComputeUnits), &nComputeUnits, nullptr );
72
+ auto ret = ddi.deviceDdiTable .pfnGetInfo (
73
+ hDevice, UR_DEVICE_INFO_MAX_COMPUTE_UNITS, sizeof (nComputeUnits),
74
+ &nComputeUnits, nullptr );
50
75
if (ret != UR_RESULT_SUCCESS) {
51
76
return {ret, {}};
52
77
}
@@ -64,15 +89,16 @@ urGetSubDevices(ur_device_handle_t hDevice) {
64
89
65
90
// Get the number of devices that will be created
66
91
uint32_t deviceCount;
67
- ret = urDevicePartition (hDevice, &properties, 0 , nullptr , &deviceCount);
92
+ ret = ddi.deviceDdiTable .pfnPartition (hDevice, &properties, 0 , nullptr ,
93
+ &deviceCount);
68
94
if (ret != UR_RESULT_SUCCESS) {
69
95
return {ret, {}};
70
96
}
71
97
72
98
std::vector<ur_device_handle_t > sub_devices (deviceCount);
73
- ret = urDevicePartition (hDevice, &properties,
74
- static_cast <uint32_t >(sub_devices.size ()),
75
- sub_devices.data (), nullptr );
99
+ ret = ddi. deviceDdiTable . pfnPartition (
100
+ hDevice, &properties, static_cast <uint32_t >(sub_devices.size ()),
101
+ sub_devices.data (), nullptr );
76
102
if (ret != UR_RESULT_SUCCESS) {
77
103
return {ret, {}};
78
104
}
@@ -82,17 +108,20 @@ urGetSubDevices(ur_device_handle_t hDevice) {
82
108
83
109
inline std::pair<ur_result_t , std::vector<ur_device_handle_t >>
84
110
urGetAllDevicesAndSubDevices (ur_context_handle_t hContext) {
111
+ static detail::ddiTables ddi;
112
+
85
113
size_t deviceCount = 0 ;
86
- auto ret = urContextGetInfo (hContext, UR_CONTEXT_INFO_NUM_DEVICES,
87
- sizeof (deviceCount), &deviceCount, nullptr );
114
+ auto ret = ddi.contextDdiTable .pfnGetInfo (
115
+ hContext, UR_CONTEXT_INFO_NUM_DEVICES, sizeof (deviceCount),
116
+ &deviceCount, nullptr );
88
117
if (ret != UR_RESULT_SUCCESS || deviceCount == 0 ) {
89
118
return {ret, {}};
90
119
}
91
120
92
121
std::vector<ur_device_handle_t > devices (deviceCount);
93
- ret = urContextGetInfo (hContext, UR_CONTEXT_INFO_DEVICES,
94
- sizeof ( ur_device_handle_t ) * deviceCount ,
95
- devices.data (), nullptr );
122
+ ret = ddi. contextDdiTable . pfnGetInfo (
123
+ hContext, UR_CONTEXT_INFO_DEVICES ,
124
+ sizeof ( ur_device_handle_t ) * deviceCount, devices.data (), nullptr );
96
125
if (ret != UR_RESULT_SUCCESS) {
97
126
return {ret, {}};
98
127
}
@@ -264,9 +293,12 @@ namespace std {
264
293
// / @brief hash specialization for usm::pool_descriptor
265
294
template <> struct hash <usm::pool_descriptor> {
266
295
inline size_t operator ()(const usm::pool_descriptor &desc) const {
296
+ static usm::detail::ddiTables ddi;
297
+
267
298
ur_native_handle_t native = 0 ;
268
299
if (desc.hDevice ) {
269
- auto ret = urDeviceGetNativeHandle (desc.hDevice , &native);
300
+ auto ret =
301
+ ddi.deviceDdiTable .pfnGetNativeHandle (desc.hDevice , &native);
270
302
if (ret != UR_RESULT_SUCCESS) {
271
303
throw ret;
272
304
}
0 commit comments