Skip to content

Commit 9ecbbef

Browse files
committed
[common] Do not use loader APIs in ur_pool_manager
Calling loader APIs is incorrect - handles would have to be translated to and from loader handles. Also, using loader APIs without explictly linking with loaders results in linking failure on Windows. Fix this, by using function pointers.
1 parent 582c88e commit 9ecbbef

File tree

1 file changed

+50
-14
lines changed

1 file changed

+50
-14
lines changed

source/common/ur_pool_manager.hpp

Lines changed: 50 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
#ifndef USM_POOL_MANAGER_HPP
1212
#define USM_POOL_MANAGER_HPP 1
1313

14+
#include <ur_ddi.h>
15+
1416
#include "logger/ur_logger.hpp"
1517
#include "umf_helpers.hpp"
1618
#include "ur_api.h"
@@ -26,6 +28,26 @@
2628

2729
namespace usm {
2830

31+
namespace detail {
32+
struct ddiTables {
33+
ddiTables() {
34+
auto ret =
35+
urGetDeviceProcAddrTable(UR_API_VERSION_CURRENT, &deviceDdiTable);
36+
if (ret != UR_RESULT_SUCCESS) {
37+
throw ret;
38+
}
39+
40+
ret =
41+
urGetContextProcAddrTable(UR_API_VERSION_CURRENT, &contextDdiTable);
42+
if (ret != UR_RESULT_SUCCESS) {
43+
throw ret;
44+
}
45+
}
46+
ur_device_dditable_t deviceDdiTable;
47+
ur_context_dditable_t contextDdiTable;
48+
};
49+
} // namespace detail
50+
2951
/// @brief describes an internal USM pool instance.
3052
struct pool_descriptor {
3153
ur_usm_pool_handle_t poolHandle;
@@ -44,9 +66,12 @@ struct pool_descriptor {
4466

4567
static inline std::pair<ur_result_t, std::vector<ur_device_handle_t>>
4668
urGetSubDevices(ur_device_handle_t hDevice) {
69+
static detail::ddiTables ddi;
70+
4771
uint32_t nComputeUnits;
48-
auto ret = urDeviceGetInfo(hDevice, UR_DEVICE_INFO_MAX_COMPUTE_UNITS,
49-
sizeof(nComputeUnits), &nComputeUnits, nullptr);
72+
auto ret = ddi.deviceDdiTable.pfnGetInfo(
73+
hDevice, UR_DEVICE_INFO_MAX_COMPUTE_UNITS, sizeof(nComputeUnits),
74+
&nComputeUnits, nullptr);
5075
if (ret != UR_RESULT_SUCCESS) {
5176
return {ret, {}};
5277
}
@@ -64,15 +89,16 @@ urGetSubDevices(ur_device_handle_t hDevice) {
6489

6590
// Get the number of devices that will be created
6691
uint32_t deviceCount;
67-
ret = urDevicePartition(hDevice, &properties, 0, nullptr, &deviceCount);
92+
ret = ddi.deviceDdiTable.pfnPartition(hDevice, &properties, 0, nullptr,
93+
&deviceCount);
6894
if (ret != UR_RESULT_SUCCESS) {
6995
return {ret, {}};
7096
}
7197

7298
std::vector<ur_device_handle_t> sub_devices(deviceCount);
73-
ret = urDevicePartition(hDevice, &properties,
74-
static_cast<uint32_t>(sub_devices.size()),
75-
sub_devices.data(), nullptr);
99+
ret = ddi.deviceDdiTable.pfnPartition(
100+
hDevice, &properties, static_cast<uint32_t>(sub_devices.size()),
101+
sub_devices.data(), nullptr);
76102
if (ret != UR_RESULT_SUCCESS) {
77103
return {ret, {}};
78104
}
@@ -82,17 +108,20 @@ urGetSubDevices(ur_device_handle_t hDevice) {
82108

83109
inline std::pair<ur_result_t, std::vector<ur_device_handle_t>>
84110
urGetAllDevicesAndSubDevices(ur_context_handle_t hContext) {
111+
static detail::ddiTables ddi;
112+
85113
size_t deviceCount = 0;
86-
auto ret = urContextGetInfo(hContext, UR_CONTEXT_INFO_NUM_DEVICES,
87-
sizeof(deviceCount), &deviceCount, nullptr);
114+
auto ret = ddi.contextDdiTable.pfnGetInfo(
115+
hContext, UR_CONTEXT_INFO_NUM_DEVICES, sizeof(deviceCount),
116+
&deviceCount, nullptr);
88117
if (ret != UR_RESULT_SUCCESS || deviceCount == 0) {
89118
return {ret, {}};
90119
}
91120

92121
std::vector<ur_device_handle_t> devices(deviceCount);
93-
ret = urContextGetInfo(hContext, UR_CONTEXT_INFO_DEVICES,
94-
sizeof(ur_device_handle_t) * deviceCount,
95-
devices.data(), nullptr);
122+
ret = ddi.contextDdiTable.pfnGetInfo(
123+
hContext, UR_CONTEXT_INFO_DEVICES,
124+
sizeof(ur_device_handle_t) * deviceCount, devices.data(), nullptr);
96125
if (ret != UR_RESULT_SUCCESS) {
97126
return {ret, {}};
98127
}
@@ -135,6 +164,8 @@ isSharedAllocationReadOnlyOnDevice(const pool_descriptor &desc) {
135164
}
136165

137166
inline bool pool_descriptor::operator==(const pool_descriptor &other) const {
167+
static usm::detail::ddiTables ddi;
168+
138169
const pool_descriptor &lhs = *this;
139170
const pool_descriptor &rhs = other;
140171
ur_native_handle_t lhsNative = 0, rhsNative = 0;
@@ -145,14 +176,16 @@ inline bool pool_descriptor::operator==(const pool_descriptor &other) const {
145176
// Ref: https://github.com/intel/llvm/commit/86511c5dc84b5781dcfd828caadcb5cac157eae1
146177
// TODO: is this L0 specific?
147178
if (lhs.hDevice) {
148-
auto ret = urDeviceGetNativeHandle(lhs.hDevice, &lhsNative);
179+
auto ret =
180+
ddi.deviceDdiTable.pfnGetNativeHandle(lhs.hDevice, &lhsNative);
149181
if (ret != UR_RESULT_SUCCESS) {
150182
throw ret;
151183
}
152184
}
153185

154186
if (rhs.hDevice) {
155-
auto ret = urDeviceGetNativeHandle(rhs.hDevice, &rhsNative);
187+
auto ret =
188+
ddi.deviceDdiTable.pfnGetNativeHandle(rhs.hDevice, &rhsNative);
156189
if (ret != UR_RESULT_SUCCESS) {
157190
throw ret;
158191
}
@@ -264,9 +297,12 @@ namespace std {
264297
/// @brief hash specialization for usm::pool_descriptor
265298
template <> struct hash<usm::pool_descriptor> {
266299
inline size_t operator()(const usm::pool_descriptor &desc) const {
300+
static usm::detail::ddiTables ddi;
301+
267302
ur_native_handle_t native = 0;
268303
if (desc.hDevice) {
269-
auto ret = urDeviceGetNativeHandle(desc.hDevice, &native);
304+
auto ret =
305+
ddi.deviceDdiTable.pfnGetNativeHandle(desc.hDevice, &native);
270306
if (ret != UR_RESULT_SUCCESS) {
271307
throw ret;
272308
}

0 commit comments

Comments
 (0)