@@ -96,7 +96,10 @@ struct AllocInfo {
96
96
97
97
// Global shared state for liboffload
98
98
struct OffloadContext ;
99
- static OffloadContext *OffloadContextVal;
99
+ // This pointer is non-null if and only if the context is valid and fully
100
+ // initialized
101
+ static std::atomic<OffloadContext *> OffloadContextVal;
102
+ std::mutex OffloadContextValMutex;
100
103
struct OffloadContext {
101
104
OffloadContext (OffloadContext &) = delete ;
102
105
OffloadContext (OffloadContext &&) = delete ;
@@ -107,6 +110,7 @@ struct OffloadContext {
107
110
bool ValidationEnabled = true ;
108
111
DenseMap<void *, AllocInfo> AllocInfoMap{};
109
112
SmallVector<ol_platform_impl_t , 4 > Platforms{};
113
+ size_t RefCount;
110
114
111
115
ol_device_handle_t HostDevice () {
112
116
// The host platform is always inserted last
@@ -145,20 +149,18 @@ constexpr ol_platform_backend_t pluginNameToBackend(StringRef Name) {
145
149
#define PLUGIN_TARGET (Name ) extern " C" GenericPluginTy *createPlugin_##Name();
146
150
#include " Shared/Targets.def"
147
151
148
- Error initPlugins () {
149
- auto *Context = new OffloadContext{};
150
-
152
+ Error initPlugins (OffloadContext &Context) {
151
153
// Attempt to create an instance of each supported plugin.
152
154
#define PLUGIN_TARGET (Name ) \
153
155
do { \
154
- Context-> Platforms .emplace_back (ol_platform_impl_t { \
156
+ Context. Platforms .emplace_back (ol_platform_impl_t { \
155
157
std::unique_ptr<GenericPluginTy>(createPlugin_##Name ()), \
156
158
pluginNameToBackend (#Name)}); \
157
159
} while (false );
158
160
#include " Shared/Targets.def"
159
161
160
162
// Preemptively initialize all devices in the plugin
161
- for (auto &Platform : Context-> Platforms ) {
163
+ for (auto &Platform : Context. Platforms ) {
162
164
// Do not use the host plugin - it isn't supported.
163
165
if (Platform.BackendType == OL_PLATFORM_BACKEND_UNKNOWN)
164
166
continue ;
@@ -178,31 +180,56 @@ Error initPlugins() {
178
180
}
179
181
180
182
// Add the special host device
181
- auto &HostPlatform = Context-> Platforms .emplace_back (
183
+ auto &HostPlatform = Context. Platforms .emplace_back (
182
184
ol_platform_impl_t {nullptr , OL_PLATFORM_BACKEND_HOST});
183
185
HostPlatform.Devices .emplace_back (-1 , nullptr , nullptr , InfoTreeNode{});
184
- Context->HostDevice ()->Platform = &HostPlatform;
185
-
186
- Context->TracingEnabled = std::getenv (" OFFLOAD_TRACE" );
187
- Context->ValidationEnabled = !std::getenv (" OFFLOAD_DISABLE_VALIDATION" );
186
+ Context.HostDevice ()->Platform = &HostPlatform;
188
187
189
- OffloadContextVal = Context;
188
+ Context.TracingEnabled = std::getenv (" OFFLOAD_TRACE" );
189
+ Context.ValidationEnabled = !std::getenv (" OFFLOAD_DISABLE_VALIDATION" );
190
190
191
191
return Plugin::success ();
192
192
}
193
193
194
- // TODO: We can properly reference count here and manage the resources in a more
195
- // clever way
196
194
Error olInit_impl () {
197
- static std::once_flag InitFlag;
198
- std::optional<Error> InitResult{};
199
- std::call_once (InitFlag, [&] { InitResult = initPlugins (); });
195
+ std::lock_guard<std::mutex> Lock{OffloadContextValMutex};
200
196
201
- if (InitResult)
202
- return std::move (*InitResult);
203
- return Error::success ();
197
+ if (isOffloadInitialized ()) {
198
+ OffloadContext::get ().RefCount ++;
199
+ return Plugin::success ();
200
+ }
201
+
202
+ // Use a temporary to ensure that entry points querying OffloadContextVal do
203
+ // not get a partially initialized context
204
+ auto *NewContext = new OffloadContext{};
205
+ Error InitResult = initPlugins (*NewContext);
206
+ OffloadContextVal.store (NewContext);
207
+ OffloadContext::get ().RefCount ++;
208
+
209
+ return InitResult;
210
+ }
211
+
212
+ Error olShutDown_impl () {
213
+ std::lock_guard<std::mutex> Lock{OffloadContextValMutex};
214
+
215
+ if (--OffloadContext::get ().RefCount != 0 )
216
+ return Error::success ();
217
+
218
+ llvm::Error Result = Error::success ();
219
+ auto *OldContext = OffloadContextVal.exchange (nullptr );
220
+
221
+ for (auto &P : OldContext->Platforms ) {
222
+ // Host plugin is nullptr and has no deinit
223
+ if (!P.Plugin )
224
+ continue ;
225
+
226
+ if (auto Res = P.Plugin ->deinit ())
227
+ Result = llvm::joinErrors (std::move (Result), std::move (Res));
228
+ }
229
+
230
+ delete OldContext;
231
+ return Result;
204
232
}
205
- Error olShutDown_impl () { return Error::success (); }
206
233
207
234
Error olGetPlatformInfoImplDetail (ol_platform_handle_t Platform,
208
235
ol_platform_info_t PropName, size_t PropSize,
0 commit comments