Skip to content

Commit 80b0648

Browse files
authored
[UR][Offload] Dynamically allocate adapter object and check for errors (#19046)
We allow `urAdapterRelease` to be called in a global destructor. To avoid any issues with destructor ordering, use `urAdapterGet` and `urAdapterRelease` to manage the offload adapter's lifetime rather than global init/fini. In addition, `adapter::init()` now actually checks and returns any error from iterating devices.
1 parent 9f559ae commit 80b0648

File tree

5 files changed

+37
-21
lines changed

5 files changed

+37
-21
lines changed

unified-runtime/source/adapters/offload/adapter.cpp

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,10 @@
1717
#include "device.hpp"
1818
#include "platform.hpp"
1919
#include "ur/ur.hpp"
20+
#include "ur2offload.hpp"
2021
#include "ur_api.h"
2122

22-
ur_adapter_handle_t_ Adapter{};
23+
ur_adapter_handle_t Adapter = nullptr;
2324

2425
// Initialize liboffload and perform the initial platform and device discovery
2526
ur_result_t ur_adapter_handle_t_::init() {
@@ -30,7 +31,7 @@ ur_result_t ur_adapter_handle_t_::init() {
3031
Res = olIterateDevices(
3132
[](ol_device_handle_t D, void *UserData) {
3233
auto *Platforms =
33-
reinterpret_cast<decltype(Adapter.Platforms) *>(UserData);
34+
reinterpret_cast<decltype(Adapter->Platforms) *>(UserData);
3435

3536
ol_platform_handle_t Platform;
3637
olGetDeviceInfo(D, OL_DEVICE_INFO_PLATFORM, sizeof(Platform),
@@ -39,7 +40,7 @@ ur_result_t ur_adapter_handle_t_::init() {
3940
olGetPlatformInfo(Platform, OL_PLATFORM_INFO_BACKEND, sizeof(Backend),
4041
&Backend);
4142
if (Backend == OL_PLATFORM_BACKEND_HOST) {
42-
Adapter.HostDevice = D;
43+
Adapter->HostDevice = D;
4344
} else if (Backend != OL_PLATFORM_BACKEND_UNKNOWN) {
4445
auto URPlatform =
4546
std::find_if(Platforms->begin(), Platforms->end(), [&](auto &P) {
@@ -57,37 +58,52 @@ ur_result_t ur_adapter_handle_t_::init() {
5758
}
5859
return false;
5960
},
60-
&Adapter.Platforms);
61+
&Adapter->Platforms);
6162

62-
(void)Res;
63-
64-
return UR_RESULT_SUCCESS;
63+
return offloadResultToUR(Res);
6564
}
6665

6766
UR_APIEXPORT ur_result_t UR_APICALL urAdapterGet(
6867
uint32_t, ur_adapter_handle_t *phAdapters, uint32_t *pNumAdapters) {
68+
std::mutex InitMutex{};
69+
6970
if (phAdapters) {
70-
if (++Adapter.RefCount == 1) {
71-
Adapter.init();
71+
std::lock_guard Guard{InitMutex};
72+
73+
// We explicitly only initialize the adapter when outputting it
74+
if (!Adapter) {
75+
Adapter = new ur_adapter_handle_t_{};
76+
auto Res = Adapter->init();
77+
if (Res) {
78+
delete Adapter;
79+
Adapter = nullptr;
80+
return Res;
81+
}
7282
}
73-
*phAdapters = &Adapter;
83+
Adapter->RefCount++;
84+
*phAdapters = Adapter;
7485
}
86+
7587
if (pNumAdapters) {
7688
*pNumAdapters = 1;
7789
}
90+
7891
return UR_RESULT_SUCCESS;
7992
}
8093

8194
UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease(ur_adapter_handle_t) {
82-
if (--Adapter.RefCount == 0) {
95+
// Doesn't need protecting by a lock - There is no way to reinitialize the
96+
// adapter after the final reference is released
97+
if (--Adapter->RefCount == 0) {
8398
// This can crash when tracing is enabled.
8499
// olShutDown();
100+
delete Adapter;
85101
};
86102
return UR_RESULT_SUCCESS;
87103
}
88104

89105
UR_APIEXPORT ur_result_t UR_APICALL urAdapterRetain(ur_adapter_handle_t) {
90-
Adapter.RefCount++;
106+
Adapter->RefCount++;
91107
return UR_RESULT_SUCCESS;
92108
}
93109

@@ -102,7 +118,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGetInfo(ur_adapter_handle_t,
102118
case UR_ADAPTER_INFO_BACKEND:
103119
return ReturnValue(UR_BACKEND_OFFLOAD);
104120
case UR_ADAPTER_INFO_REFERENCE_COUNT:
105-
return ReturnValue(Adapter.RefCount.load());
121+
return ReturnValue(Adapter->RefCount.load());
106122
case UR_ADAPTER_INFO_VERSION:
107123
return ReturnValue(1);
108124
default:
@@ -124,15 +140,15 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterSetLoggerCallback(
124140
ur_adapter_handle_t, ur_logger_callback_t pfnLoggerCallback,
125141
void *pUserData, ur_logger_level_t level = UR_LOGGER_LEVEL_QUIET) {
126142

127-
Adapter.Logger.setCallbackSink(pfnLoggerCallback, pUserData, level);
143+
Adapter->Logger.setCallbackSink(pfnLoggerCallback, pUserData, level);
128144

129145
return UR_RESULT_SUCCESS;
130146
}
131147

132148
UR_APIEXPORT ur_result_t UR_APICALL
133149
urAdapterSetLoggerCallbackLevel(ur_adapter_handle_t, ur_logger_level_t level) {
134150

135-
Adapter.Logger.setCallbackLevel(level);
151+
Adapter->Logger.setCallbackLevel(level);
136152

137153
return UR_RESULT_SUCCESS;
138154
}

unified-runtime/source/adapters/offload/adapter.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,4 +29,4 @@ struct ur_adapter_handle_t_ : ur::offload::handle_base {
2929
ur_result_t init();
3030
};
3131

32-
extern ur_adapter_handle_t_ Adapter;
32+
extern ur_adapter_handle_t Adapter;

unified-runtime/source/adapters/offload/enqueue.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferRead(
112112
char *DevPtr =
113113
reinterpret_cast<char *>(std::get<BufferMem>(hBuffer->Mem).Ptr);
114114

115-
olMemcpy(hQueue->OffloadQueue, pDst, Adapter.HostDevice, DevPtr + offset,
115+
olMemcpy(hQueue->OffloadQueue, pDst, Adapter->HostDevice, DevPtr + offset,
116116
hQueue->OffloadDevice, size, phEvent ? &EventOut : nullptr);
117117

118118
if (blockingRead) {
@@ -145,7 +145,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWrite(
145145

146146
auto Res =
147147
olMemcpy(hQueue->OffloadQueue, DevPtr + offset, hQueue->OffloadDevice,
148-
pSrc, Adapter.HostDevice, size, phEvent ? &EventOut : nullptr);
148+
pSrc, Adapter->HostDevice, size, phEvent ? &EventOut : nullptr);
149149
if (Res) {
150150
return offloadResultToUR(Res);
151151
}

unified-runtime/source/adapters/offload/memory.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemBufferCreate(
6060

6161
if (PerformInitialCopy) {
6262
auto Res = olMemcpy(nullptr, Ptr, OffloadDevice, HostPtr,
63-
Adapter.HostDevice, size, nullptr);
63+
Adapter->HostDevice, size, nullptr);
6464
if (Res) {
6565
return offloadResultToUR(Res);
6666
}

unified-runtime/source/adapters/offload/platform.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,12 @@ urPlatformGet(ur_adapter_handle_t, uint32_t NumEntries,
2222
ur_platform_handle_t *phPlatforms, uint32_t *pNumPlatforms) {
2323

2424
if (pNumPlatforms) {
25-
*pNumPlatforms = Adapter.Platforms.size();
25+
*pNumPlatforms = Adapter->Platforms.size();
2626
}
2727

2828
if (phPlatforms) {
2929
size_t PlatformIndex = 0;
30-
for (auto &Platform : Adapter.Platforms) {
30+
for (auto &Platform : Adapter->Platforms) {
3131
phPlatforms[PlatformIndex++] = Platform.get();
3232
if (PlatformIndex == NumEntries) {
3333
break;

0 commit comments

Comments
 (0)