Skip to content

Commit 53336ad

Browse files
authored
[Offload] Move (most) global state to an OffloadContext struct (#144494)
Rather than having a number of static local variables, we now use a single `OffloadContext` struct to store global state. This is initialised by `olInit`, but is never deleted (de-initialization of Offload isn't yet implemented). The error reporting mechanism has not been moved to the struct, since that's going to cause issues with teardown (error messages must outlive liboffload).
1 parent 9fd22cb commit 53336ad

File tree

4 files changed

+80
-56
lines changed

4 files changed

+80
-56
lines changed

offload/liboffload/include/OffloadImpl.hpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,12 @@
2222
#include "llvm/ADT/StringSet.h"
2323
#include "llvm/Support/Error.h"
2424

25-
struct OffloadConfig {
26-
bool TracingEnabled = false;
27-
bool ValidationEnabled = true;
28-
};
29-
30-
OffloadConfig &offloadConfig();
25+
namespace llvm {
26+
namespace offload {
27+
bool isTracingEnabled();
28+
bool isValidationEnabled();
29+
} // namespace offload
30+
} // namespace llvm
3131

3232
// Use the StringSet container to efficiently deduplicate repeated error
3333
// strings (e.g. if the same error is hit constantly in a long running program)

offload/liboffload/src/OffloadImpl.cpp

Lines changed: 51 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -93,22 +93,36 @@ struct AllocInfo {
9393
ol_alloc_type_t Type;
9494
};
9595

96-
using AllocInfoMapT = DenseMap<void *, AllocInfo>;
97-
AllocInfoMapT &allocInfoMap() {
98-
static AllocInfoMapT AllocInfoMap{};
99-
return AllocInfoMap;
100-
}
96+
// Global shared state for liboffload
97+
struct OffloadContext;
98+
static OffloadContext *OffloadContextVal;
99+
struct OffloadContext {
100+
OffloadContext(OffloadContext &) = delete;
101+
OffloadContext(OffloadContext &&) = delete;
102+
OffloadContext &operator=(OffloadContext &) = delete;
103+
OffloadContext &operator=(OffloadContext &&) = delete;
104+
105+
bool TracingEnabled = false;
106+
bool ValidationEnabled = true;
107+
DenseMap<void *, AllocInfo> AllocInfoMap{};
108+
SmallVector<ol_platform_impl_t, 4> Platforms{};
109+
110+
ol_device_handle_t HostDevice() {
111+
// The host platform is always inserted last
112+
return &Platforms.back().Devices[0];
113+
}
101114

102-
using PlatformVecT = SmallVector<ol_platform_impl_t, 4>;
103-
PlatformVecT &Platforms() {
104-
static PlatformVecT Platforms;
105-
return Platforms;
106-
}
115+
static OffloadContext &get() {
116+
assert(OffloadContextVal);
117+
return *OffloadContextVal;
118+
}
119+
};
107120

108-
ol_device_handle_t HostDevice() {
109-
// The host platform is always inserted last
110-
return &Platforms().back().Devices[0];
121+
// If the context is uninited, then we assume tracing is disabled
122+
bool isTracingEnabled() {
123+
return OffloadContextVal && OffloadContext::get().TracingEnabled;
111124
}
125+
bool isValidationEnabled() { return OffloadContext::get().ValidationEnabled; }
112126

113127
template <typename HandleT> Error olDestroy(HandleT Handle) {
114128
delete Handle;
@@ -130,18 +144,20 @@ constexpr ol_platform_backend_t pluginNameToBackend(StringRef Name) {
130144
#include "Shared/Targets.def"
131145

132146
void initPlugins() {
147+
auto *Context = new OffloadContext{};
148+
133149
// Attempt to create an instance of each supported plugin.
134150
#define PLUGIN_TARGET(Name) \
135151
do { \
136-
Platforms().emplace_back(ol_platform_impl_t{ \
152+
Context->Platforms.emplace_back(ol_platform_impl_t{ \
137153
std::unique_ptr<GenericPluginTy>(createPlugin_##Name()), \
138154
{}, \
139155
pluginNameToBackend(#Name)}); \
140156
} while (false);
141157
#include "Shared/Targets.def"
142158

143159
// Preemptively initialize all devices in the plugin
144-
for (auto &Platform : Platforms()) {
160+
for (auto &Platform : Context->Platforms) {
145161
// Do not use the host plugin - it isn't supported.
146162
if (Platform.BackendType == OL_PLATFORM_BACKEND_UNKNOWN)
147163
continue;
@@ -157,15 +173,16 @@ void initPlugins() {
157173
}
158174

159175
// Add the special host device
160-
auto &HostPlatform = Platforms().emplace_back(
176+
auto &HostPlatform = Context->Platforms.emplace_back(
161177
ol_platform_impl_t{nullptr,
162178
{ol_device_impl_t{-1, nullptr, nullptr}},
163179
OL_PLATFORM_BACKEND_HOST});
164-
HostDevice()->Platform = &HostPlatform;
180+
Context->HostDevice()->Platform = &HostPlatform;
181+
182+
Context->TracingEnabled = std::getenv("OFFLOAD_TRACE");
183+
Context->ValidationEnabled = !std::getenv("OFFLOAD_DISABLE_VALIDATION");
165184

166-
offloadConfig().TracingEnabled = std::getenv("OFFLOAD_TRACE");
167-
offloadConfig().ValidationEnabled =
168-
!std::getenv("OFFLOAD_DISABLE_VALIDATION");
185+
OffloadContextVal = Context;
169186
}
170187

171188
// TODO: We can properly reference count here and manage the resources in a more
@@ -229,7 +246,7 @@ Error olGetDeviceInfoImplDetail(ol_device_handle_t Device,
229246

230247
// Find the info if it exists under any of the given names
231248
auto GetInfo = [&](std::vector<std::string> Names) {
232-
if (Device == HostDevice())
249+
if (Device == OffloadContext::get().HostDevice())
233250
return std::string("Host");
234251

235252
if (!Device->Device)
@@ -251,8 +268,9 @@ Error olGetDeviceInfoImplDetail(ol_device_handle_t Device,
251268
case OL_DEVICE_INFO_PLATFORM:
252269
return ReturnValue(Device->Platform);
253270
case OL_DEVICE_INFO_TYPE:
254-
return Device == HostDevice() ? ReturnValue(OL_DEVICE_TYPE_HOST)
255-
: ReturnValue(OL_DEVICE_TYPE_GPU);
271+
return Device == OffloadContext::get().HostDevice()
272+
? ReturnValue(OL_DEVICE_TYPE_HOST)
273+
: ReturnValue(OL_DEVICE_TYPE_GPU);
256274
case OL_DEVICE_INFO_NAME:
257275
return ReturnValue(GetInfo({"Device Name"}).c_str());
258276
case OL_DEVICE_INFO_VENDOR:
@@ -280,7 +298,7 @@ Error olGetDeviceInfoSize_impl(ol_device_handle_t Device,
280298
}
281299

282300
Error olIterateDevices_impl(ol_device_iterate_cb_t Callback, void *UserData) {
283-
for (auto &Platform : Platforms()) {
301+
for (auto &Platform : OffloadContext::get().Platforms) {
284302
for (auto &Device : Platform.Devices) {
285303
if (!Callback(&Device, UserData)) {
286304
break;
@@ -311,24 +329,25 @@ Error olMemAlloc_impl(ol_device_handle_t Device, ol_alloc_type_t Type,
311329
return Alloc.takeError();
312330

313331
*AllocationOut = *Alloc;
314-
allocInfoMap().insert_or_assign(*Alloc, AllocInfo{Device, Type});
332+
OffloadContext::get().AllocInfoMap.insert_or_assign(*Alloc,
333+
AllocInfo{Device, Type});
315334
return Error::success();
316335
}
317336

318337
Error olMemFree_impl(void *Address) {
319-
if (!allocInfoMap().contains(Address))
338+
if (!OffloadContext::get().AllocInfoMap.contains(Address))
320339
return createOffloadError(ErrorCode::INVALID_ARGUMENT,
321340
"address is not a known allocation");
322341

323-
auto AllocInfo = allocInfoMap().at(Address);
342+
auto AllocInfo = OffloadContext::get().AllocInfoMap.at(Address);
324343
auto Device = AllocInfo.Device;
325344
auto Type = AllocInfo.Type;
326345

327346
if (auto Res =
328347
Device->Device->dataDelete(Address, convertOlToPluginAllocTy(Type)))
329348
return Res;
330349

331-
allocInfoMap().erase(Address);
350+
OffloadContext::get().AllocInfoMap.erase(Address);
332351

333352
return Error::success();
334353
}
@@ -395,7 +414,8 @@ Error olMemcpy_impl(ol_queue_handle_t Queue, void *DstPtr,
395414
ol_device_handle_t DstDevice, const void *SrcPtr,
396415
ol_device_handle_t SrcDevice, size_t Size,
397416
ol_event_handle_t *EventOut) {
398-
if (DstDevice == HostDevice() && SrcDevice == HostDevice()) {
417+
auto Host = OffloadContext::get().HostDevice();
418+
if (DstDevice == Host && SrcDevice == Host) {
399419
if (!Queue) {
400420
std::memcpy(DstPtr, SrcPtr, Size);
401421
return Error::success();
@@ -410,11 +430,11 @@ Error olMemcpy_impl(ol_queue_handle_t Queue, void *DstPtr,
410430
// If no queue is given the memcpy will be synchronous
411431
auto QueueImpl = Queue ? Queue->AsyncInfo : nullptr;
412432

413-
if (DstDevice == HostDevice()) {
433+
if (DstDevice == Host) {
414434
if (auto Res =
415435
SrcDevice->Device->dataRetrieve(DstPtr, SrcPtr, Size, QueueImpl))
416436
return Res;
417-
} else if (SrcDevice == HostDevice()) {
437+
} else if (SrcDevice == Host) {
418438
if (auto Res =
419439
DstDevice->Device->dataSubmit(DstPtr, SrcPtr, Size, QueueImpl))
420440
return Res;

offload/liboffload/src/OffloadLib.cpp

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,6 @@ ol_code_location_t *&currentCodeLocation() {
3030
return CodeLoc;
3131
}
3232

33-
OffloadConfig &offloadConfig() {
34-
static OffloadConfig Config{};
35-
return Config;
36-
}
37-
3833
namespace llvm {
3934
namespace offload {
4035
// Pull in the declarations for the implementation functions. The actual entry

offload/tools/offload-tblgen/EntryPointGen.cpp

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -35,21 +35,30 @@ static void EmitValidationFunc(const FunctionRec &F, raw_ostream &OS) {
3535
}
3636
OS << ") {\n";
3737

38-
OS << TAB_1 "if (offloadConfig().ValidationEnabled) {\n";
39-
// Emit validation checks
40-
for (const auto &Return : F.getReturns()) {
41-
for (auto &Condition : Return.getConditions()) {
42-
if (Condition.starts_with("`") && Condition.ends_with("`")) {
43-
auto ConditionString = Condition.substr(1, Condition.size() - 2);
44-
OS << formatv(TAB_2 "if ({0}) {{\n", ConditionString);
45-
OS << formatv(TAB_3 "return createOffloadError(error::ErrorCode::{0}, "
46-
"\"validation failure: {1}\");\n",
47-
Return.getUnprefixedValue(), ConditionString);
48-
OS << TAB_2 "}\n\n";
38+
bool HasValidation = llvm::any_of(F.getReturns(), [](auto &R) {
39+
return llvm::any_of(R.getConditions(), [](auto &C) {
40+
return C.starts_with("`") && C.ends_with("`");
41+
});
42+
});
43+
44+
if (HasValidation) {
45+
OS << TAB_1 "if (llvm::offload::isValidationEnabled()) {\n";
46+
// Emit validation checks
47+
for (const auto &Return : F.getReturns()) {
48+
for (auto &Condition : Return.getConditions()) {
49+
if (Condition.starts_with("`") && Condition.ends_with("`")) {
50+
auto ConditionString = Condition.substr(1, Condition.size() - 2);
51+
OS << formatv(TAB_2 "if ({0}) {{\n", ConditionString);
52+
OS << formatv(TAB_3
53+
"return createOffloadError(error::ErrorCode::{0}, "
54+
"\"validation failure: {1}\");\n",
55+
Return.getUnprefixedValue(), ConditionString);
56+
OS << TAB_2 "}\n\n";
57+
}
4958
}
5059
}
60+
OS << TAB_1 "}\n\n";
5161
}
52-
OS << TAB_1 "}\n\n";
5362

5463
// Perform actual function call to the implementation
5564
ParamNameList = ParamNameList.substr(0, ParamNameList.size() - 2);
@@ -74,7 +83,7 @@ static void EmitEntryPointFunc(const FunctionRec &F, raw_ostream &OS) {
7483
OS << ") {\n";
7584

7685
// Emit pre-call prints
77-
OS << TAB_1 "if (offloadConfig().TracingEnabled) {\n";
86+
OS << TAB_1 "if (llvm::offload::isTracingEnabled()) {\n";
7887
OS << formatv(TAB_2 "llvm::errs() << \"---> {0}\";\n", F.getName());
7988
OS << TAB_1 "}\n\n";
8089

@@ -85,7 +94,7 @@ static void EmitEntryPointFunc(const FunctionRec &F, raw_ostream &OS) {
8594
PrefixLower, F.getName(), ParamNameList);
8695

8796
// Emit post-call prints
88-
OS << TAB_1 "if (offloadConfig().TracingEnabled) {\n";
97+
OS << TAB_1 "if (llvm::offload::isTracingEnabled()) {\n";
8998
if (F.getParams().size() > 0) {
9099
OS << formatv(TAB_2 "{0} Params = {{", F.getParamStructName());
91100
for (const auto &Param : F.getParams()) {

0 commit comments

Comments
 (0)