Skip to content

Commit 003145d

Browse files
authored
[Offload] Implement olShutDown (#144055)
`olShutDown` was not properly calling `deinit` on the platforms, resulting in random segfaults on AMD devices. As part of this fix, `olInit` and `olShutDown` now allocate and free the offload context rather than the context being static. This allows `olShutDown` to be called from within the destructor of a static object (as the tests do) without having to worry about destructor ordering.
1 parent 6e6c61d commit 003145d

File tree

3 files changed

+61
-22
lines changed

3 files changed

+61
-22
lines changed

offload/liboffload/API/Common.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ def : Function {
176176
let desc = "Release the resources in use by Offload";
177177
let details = [
178178
"This decrements an internal reference count. When this reaches 0, all resources will be released",
179-
"Subsequent API calls made after this are not valid"
179+
"Subsequent API calls to methods other than `olInit` made after resources are released will return OL_ERRC_UNINITIALIZED"
180180
];
181181
let params = [];
182182
let returns = [];

offload/liboffload/src/OffloadImpl.cpp

Lines changed: 48 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,10 @@ struct AllocInfo {
9696

9797
// Global shared state for liboffload
9898
struct OffloadContext;
99-
static OffloadContext *OffloadContextVal;
99+
// This pointer is non-null if and only if the context is valid and fully
100+
// initialized
101+
static std::atomic<OffloadContext *> OffloadContextVal;
102+
std::mutex OffloadContextValMutex;
100103
struct OffloadContext {
101104
OffloadContext(OffloadContext &) = delete;
102105
OffloadContext(OffloadContext &&) = delete;
@@ -107,6 +110,7 @@ struct OffloadContext {
107110
bool ValidationEnabled = true;
108111
DenseMap<void *, AllocInfo> AllocInfoMap{};
109112
SmallVector<ol_platform_impl_t, 4> Platforms{};
113+
size_t RefCount;
110114

111115
ol_device_handle_t HostDevice() {
112116
// The host platform is always inserted last
@@ -145,20 +149,18 @@ constexpr ol_platform_backend_t pluginNameToBackend(StringRef Name) {
145149
#define PLUGIN_TARGET(Name) extern "C" GenericPluginTy *createPlugin_##Name();
146150
#include "Shared/Targets.def"
147151

148-
Error initPlugins() {
149-
auto *Context = new OffloadContext{};
150-
152+
Error initPlugins(OffloadContext &Context) {
151153
// Attempt to create an instance of each supported plugin.
152154
#define PLUGIN_TARGET(Name) \
153155
do { \
154-
Context->Platforms.emplace_back(ol_platform_impl_t{ \
156+
Context.Platforms.emplace_back(ol_platform_impl_t{ \
155157
std::unique_ptr<GenericPluginTy>(createPlugin_##Name()), \
156158
pluginNameToBackend(#Name)}); \
157159
} while (false);
158160
#include "Shared/Targets.def"
159161

160162
// Preemptively initialize all devices in the plugin
161-
for (auto &Platform : Context->Platforms) {
163+
for (auto &Platform : Context.Platforms) {
162164
// Do not use the host plugin - it isn't supported.
163165
if (Platform.BackendType == OL_PLATFORM_BACKEND_UNKNOWN)
164166
continue;
@@ -178,31 +180,56 @@ Error initPlugins() {
178180
}
179181

180182
// Add the special host device
181-
auto &HostPlatform = Context->Platforms.emplace_back(
183+
auto &HostPlatform = Context.Platforms.emplace_back(
182184
ol_platform_impl_t{nullptr, OL_PLATFORM_BACKEND_HOST});
183185
HostPlatform.Devices.emplace_back(-1, nullptr, nullptr, InfoTreeNode{});
184-
Context->HostDevice()->Platform = &HostPlatform;
185-
186-
Context->TracingEnabled = std::getenv("OFFLOAD_TRACE");
187-
Context->ValidationEnabled = !std::getenv("OFFLOAD_DISABLE_VALIDATION");
186+
Context.HostDevice()->Platform = &HostPlatform;
188187

189-
OffloadContextVal = Context;
188+
Context.TracingEnabled = std::getenv("OFFLOAD_TRACE");
189+
Context.ValidationEnabled = !std::getenv("OFFLOAD_DISABLE_VALIDATION");
190190

191191
return Plugin::success();
192192
}
193193

194-
// TODO: We can properly reference count here and manage the resources in a more
195-
// clever way
196194
Error olInit_impl() {
197-
static std::once_flag InitFlag;
198-
std::optional<Error> InitResult{};
199-
std::call_once(InitFlag, [&] { InitResult = initPlugins(); });
195+
std::lock_guard<std::mutex> Lock{OffloadContextValMutex};
200196

201-
if (InitResult)
202-
return std::move(*InitResult);
203-
return Error::success();
197+
if (isOffloadInitialized()) {
198+
OffloadContext::get().RefCount++;
199+
return Plugin::success();
200+
}
201+
202+
// Use a temporary to ensure that entry points querying OffloadContextVal do
203+
// not get a partially initialized context
204+
auto *NewContext = new OffloadContext{};
205+
Error InitResult = initPlugins(*NewContext);
206+
OffloadContextVal.store(NewContext);
207+
OffloadContext::get().RefCount++;
208+
209+
return InitResult;
210+
}
211+
212+
Error olShutDown_impl() {
213+
std::lock_guard<std::mutex> Lock{OffloadContextValMutex};
214+
215+
if (--OffloadContext::get().RefCount != 0)
216+
return Error::success();
217+
218+
llvm::Error Result = Error::success();
219+
auto *OldContext = OffloadContextVal.exchange(nullptr);
220+
221+
for (auto &P : OldContext->Platforms) {
222+
// Host plugin is nullptr and has no deinit
223+
if (!P.Plugin)
224+
continue;
225+
226+
if (auto Res = P.Plugin->deinit())
227+
Result = llvm::joinErrors(std::move(Result), std::move(Res));
228+
}
229+
230+
delete OldContext;
231+
return Result;
204232
}
205-
Error olShutDown_impl() { return Error::success(); }
206233

207234
Error olGetPlatformInfoImplDetail(ol_platform_handle_t Platform,
208235
ol_platform_info_t PropName, size_t PropSize,

offload/unittests/OffloadAPI/init/olInit.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,20 @@
1515

1616
struct olInitTest : ::testing::Test {};
1717

18+
TEST_F(olInitTest, Success) {
19+
ASSERT_SUCCESS(olInit());
20+
ASSERT_SUCCESS(olShutDown());
21+
}
22+
1823
TEST_F(olInitTest, Uninitialized) {
1924
ASSERT_ERROR(OL_ERRC_UNINITIALIZED,
2025
olIterateDevices(
2126
[](ol_device_handle_t, void *) { return false; }, nullptr));
2227
}
28+
29+
TEST_F(olInitTest, RepeatedInit) {
30+
for (size_t I = 0; I < 10; I++) {
31+
ASSERT_SUCCESS(olInit());
32+
ASSERT_SUCCESS(olShutDown());
33+
}
34+
}

0 commit comments

Comments
 (0)