11
11
#ifndef USM_POOL_MANAGER_HPP
12
12
#define USM_POOL_MANAGER_HPP 1
13
13
14
+ #include < ur_ddi.h>
15
+
14
16
#include " logger/ur_logger.hpp"
15
17
#include " umf_helpers.hpp"
16
18
#include " ur_api.h"
26
28
27
29
namespace usm {
28
30
31
+ namespace detail {
32
+ struct ddiTables {
33
+ ddiTables () {
34
+ auto ret =
35
+ urGetDeviceProcAddrTable (UR_API_VERSION_CURRENT, &deviceDdiTable);
36
+ if (ret != UR_RESULT_SUCCESS) {
37
+ throw ret;
38
+ }
39
+
40
+ ret =
41
+ urGetContextProcAddrTable (UR_API_VERSION_CURRENT, &contextDdiTable);
42
+ if (ret != UR_RESULT_SUCCESS) {
43
+ throw ret;
44
+ }
45
+ }
46
+ ur_device_dditable_t deviceDdiTable;
47
+ ur_context_dditable_t contextDdiTable;
48
+ };
49
+ } // namespace detail
50
+
29
51
// / @brief describes an internal USM pool instance.
30
52
struct pool_descriptor {
31
53
ur_usm_pool_handle_t poolHandle;
@@ -44,9 +66,12 @@ struct pool_descriptor {
44
66
45
67
static inline std::pair<ur_result_t , std::vector<ur_device_handle_t >>
46
68
urGetSubDevices (ur_device_handle_t hDevice) {
69
+ static detail::ddiTables ddi;
70
+
47
71
uint32_t nComputeUnits;
48
- auto ret = urDeviceGetInfo (hDevice, UR_DEVICE_INFO_MAX_COMPUTE_UNITS,
49
- sizeof (nComputeUnits), &nComputeUnits, nullptr );
72
+ auto ret = ddi.deviceDdiTable .pfnGetInfo (
73
+ hDevice, UR_DEVICE_INFO_MAX_COMPUTE_UNITS, sizeof (nComputeUnits),
74
+ &nComputeUnits, nullptr );
50
75
if (ret != UR_RESULT_SUCCESS) {
51
76
return {ret, {}};
52
77
}
@@ -64,15 +89,16 @@ urGetSubDevices(ur_device_handle_t hDevice) {
64
89
65
90
// Get the number of devices that will be created
66
91
uint32_t deviceCount;
67
- ret = urDevicePartition (hDevice, &properties, 0 , nullptr , &deviceCount);
92
+ ret = ddi.deviceDdiTable .pfnPartition (hDevice, &properties, 0 , nullptr ,
93
+ &deviceCount);
68
94
if (ret != UR_RESULT_SUCCESS) {
69
95
return {ret, {}};
70
96
}
71
97
72
98
std::vector<ur_device_handle_t > sub_devices (deviceCount);
73
- ret = urDevicePartition (hDevice, &properties,
74
- static_cast <uint32_t >(sub_devices.size ()),
75
- sub_devices.data (), nullptr );
99
+ ret = ddi. deviceDdiTable . pfnPartition (
100
+ hDevice, &properties, static_cast <uint32_t >(sub_devices.size ()),
101
+ sub_devices.data (), nullptr );
76
102
if (ret != UR_RESULT_SUCCESS) {
77
103
return {ret, {}};
78
104
}
@@ -82,17 +108,20 @@ urGetSubDevices(ur_device_handle_t hDevice) {
82
108
83
109
inline std::pair<ur_result_t , std::vector<ur_device_handle_t >>
84
110
urGetAllDevicesAndSubDevices (ur_context_handle_t hContext) {
111
+ static detail::ddiTables ddi;
112
+
85
113
size_t deviceCount = 0 ;
86
- auto ret = urContextGetInfo (hContext, UR_CONTEXT_INFO_NUM_DEVICES,
87
- sizeof (deviceCount), &deviceCount, nullptr );
114
+ auto ret = ddi.contextDdiTable .pfnGetInfo (
115
+ hContext, UR_CONTEXT_INFO_NUM_DEVICES, sizeof (deviceCount),
116
+ &deviceCount, nullptr );
88
117
if (ret != UR_RESULT_SUCCESS || deviceCount == 0 ) {
89
118
return {ret, {}};
90
119
}
91
120
92
121
std::vector<ur_device_handle_t > devices (deviceCount);
93
- ret = urContextGetInfo (hContext, UR_CONTEXT_INFO_DEVICES,
94
- sizeof ( ur_device_handle_t ) * deviceCount ,
95
- devices.data (), nullptr );
122
+ ret = ddi. contextDdiTable . pfnGetInfo (
123
+ hContext, UR_CONTEXT_INFO_DEVICES ,
124
+ sizeof ( ur_device_handle_t ) * deviceCount, devices.data (), nullptr );
96
125
if (ret != UR_RESULT_SUCCESS) {
97
126
return {ret, {}};
98
127
}
@@ -135,6 +164,8 @@ isSharedAllocationReadOnlyOnDevice(const pool_descriptor &desc) {
135
164
}
136
165
137
166
inline bool pool_descriptor::operator ==(const pool_descriptor &other) const {
167
+ static usm::detail::ddiTables ddi;
168
+
138
169
const pool_descriptor &lhs = *this ;
139
170
const pool_descriptor &rhs = other;
140
171
ur_native_handle_t lhsNative = 0 , rhsNative = 0 ;
@@ -145,14 +176,16 @@ inline bool pool_descriptor::operator==(const pool_descriptor &other) const {
145
176
// Ref: https://github.com/intel/llvm/commit/86511c5dc84b5781dcfd828caadcb5cac157eae1
146
177
// TODO: is this L0 specific?
147
178
if (lhs.hDevice ) {
148
- auto ret = urDeviceGetNativeHandle (lhs.hDevice , &lhsNative);
179
+ auto ret =
180
+ ddi.deviceDdiTable .pfnGetNativeHandle (lhs.hDevice , &lhsNative);
149
181
if (ret != UR_RESULT_SUCCESS) {
150
182
throw ret;
151
183
}
152
184
}
153
185
154
186
if (rhs.hDevice ) {
155
- auto ret = urDeviceGetNativeHandle (rhs.hDevice , &rhsNative);
187
+ auto ret =
188
+ ddi.deviceDdiTable .pfnGetNativeHandle (rhs.hDevice , &rhsNative);
156
189
if (ret != UR_RESULT_SUCCESS) {
157
190
throw ret;
158
191
}
@@ -264,9 +297,12 @@ namespace std {
264
297
// / @brief hash specialization for usm::pool_descriptor
265
298
template <> struct hash <usm::pool_descriptor> {
266
299
inline size_t operator ()(const usm::pool_descriptor &desc) const {
300
+ static usm::detail::ddiTables ddi;
301
+
267
302
ur_native_handle_t native = 0 ;
268
303
if (desc.hDevice ) {
269
- auto ret = urDeviceGetNativeHandle (desc.hDevice , &native);
304
+ auto ret =
305
+ ddi.deviceDdiTable .pfnGetNativeHandle (desc.hDevice , &native);
270
306
if (ret != UR_RESULT_SUCCESS) {
271
307
throw ret;
272
308
}
0 commit comments