-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[Offload] Move (most) global state to an OffloadContext
struct
#144494
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Rather than having a number of static local variables, we now use a single `OffloadContext` struct to store global state. This is initialised by `olInit`, but is never deleted (de-initialization of Offload isn't yet implemented). The error reporting mechanism has not been moved to the struct, since that's going to cause issues with teardown (error messages must outlive liboffload).
The global state stuff looks good to me, but I wonder if we can make tracing of |
@callumfare I've just updated it so that |
@jhuber6 Mind having a look at this? Also, for some reason the notification/tagging bot didn't see this. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A good start.
OffloadContext(OffloadContext &) = delete; | ||
OffloadContext(OffloadContext &&) = delete; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You probably want to delete everything except the default constructor..
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've added "deletes" for operator=
, anything else you want me to add?
In the future, olShutDown will need to destruct this object, so I don't think ~OffloadContext
could be deleted.
@llvm/pr-subscribers-offload Author: Ross Brunton (RossBrunton) ChangesRather than having a number of static local variables, we now use The error reporting mechanism has not been moved to the struct, since Full diff: https://github.com/llvm/llvm-project/pull/144494.diff 4 Files Affected:
diff --git a/offload/liboffload/include/OffloadImpl.hpp b/offload/liboffload/include/OffloadImpl.hpp
index 9b0a21cb9ae12..a12d8c47a180b 100644
--- a/offload/liboffload/include/OffloadImpl.hpp
+++ b/offload/liboffload/include/OffloadImpl.hpp
@@ -22,12 +22,12 @@
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/Error.h"
-struct OffloadConfig {
- bool TracingEnabled = false;
- bool ValidationEnabled = true;
-};
-
-OffloadConfig &offloadConfig();
+namespace llvm {
+namespace offload {
+bool isTracingEnabled();
+bool isValidationEnabled();
+} // namespace offload
+} // namespace llvm
// Use the StringSet container to efficiently deduplicate repeated error
// strings (e.g. if the same error is hit constantly in a long running program)
diff --git a/offload/liboffload/src/OffloadImpl.cpp b/offload/liboffload/src/OffloadImpl.cpp
index 770c212d804d2..f02497c0a6331 100644
--- a/offload/liboffload/src/OffloadImpl.cpp
+++ b/offload/liboffload/src/OffloadImpl.cpp
@@ -93,22 +93,36 @@ struct AllocInfo {
ol_alloc_type_t Type;
};
-using AllocInfoMapT = DenseMap<void *, AllocInfo>;
-AllocInfoMapT &allocInfoMap() {
- static AllocInfoMapT AllocInfoMap{};
- return AllocInfoMap;
-}
+// Global shared state for liboffload
+struct OffloadContext;
+static OffloadContext *OffloadContextVal;
+struct OffloadContext {
+ OffloadContext(OffloadContext &) = delete;
+ OffloadContext(OffloadContext &&) = delete;
+ OffloadContext &operator=(OffloadContext &) = delete;
+ OffloadContext &operator=(OffloadContext &&) = delete;
+
+ bool TracingEnabled = false;
+ bool ValidationEnabled = true;
+ DenseMap<void *, AllocInfo> AllocInfoMap{};
+ SmallVector<ol_platform_impl_t, 4> Platforms{};
+
+ ol_device_handle_t HostDevice() {
+ // The host platform is always inserted last
+ return &Platforms.back().Devices[0];
+ }
-using PlatformVecT = SmallVector<ol_platform_impl_t, 4>;
-PlatformVecT &Platforms() {
- static PlatformVecT Platforms;
- return Platforms;
-}
+ static OffloadContext &get() {
+ assert(OffloadContextVal);
+ return *OffloadContextVal;
+ }
+};
-ol_device_handle_t HostDevice() {
- // The host platform is always inserted last
- return &Platforms().back().Devices[0];
+// If the context is uninited, then we assume tracing is disabled
+bool isTracingEnabled() {
+ return OffloadContextVal && OffloadContext::get().TracingEnabled;
}
+bool isValidationEnabled() { return OffloadContext::get().ValidationEnabled; }
template <typename HandleT> Error olDestroy(HandleT Handle) {
delete Handle;
@@ -130,10 +144,12 @@ constexpr ol_platform_backend_t pluginNameToBackend(StringRef Name) {
#include "Shared/Targets.def"
void initPlugins() {
+ auto *Context = new OffloadContext{};
+
// Attempt to create an instance of each supported plugin.
#define PLUGIN_TARGET(Name) \
do { \
- Platforms().emplace_back(ol_platform_impl_t{ \
+ Context->Platforms.emplace_back(ol_platform_impl_t{ \
std::unique_ptr<GenericPluginTy>(createPlugin_##Name()), \
{}, \
pluginNameToBackend(#Name)}); \
@@ -141,7 +157,7 @@ void initPlugins() {
#include "Shared/Targets.def"
// Preemptively initialize all devices in the plugin
- for (auto &Platform : Platforms()) {
+ for (auto &Platform : Context->Platforms) {
// Do not use the host plugin - it isn't supported.
if (Platform.BackendType == OL_PLATFORM_BACKEND_UNKNOWN)
continue;
@@ -157,15 +173,16 @@ void initPlugins() {
}
// Add the special host device
- auto &HostPlatform = Platforms().emplace_back(
+ auto &HostPlatform = Context->Platforms.emplace_back(
ol_platform_impl_t{nullptr,
{ol_device_impl_t{-1, nullptr, nullptr}},
OL_PLATFORM_BACKEND_HOST});
- HostDevice()->Platform = &HostPlatform;
+ Context->HostDevice()->Platform = &HostPlatform;
+
+ Context->TracingEnabled = std::getenv("OFFLOAD_TRACE");
+ Context->ValidationEnabled = !std::getenv("OFFLOAD_DISABLE_VALIDATION");
- offloadConfig().TracingEnabled = std::getenv("OFFLOAD_TRACE");
- offloadConfig().ValidationEnabled =
- !std::getenv("OFFLOAD_DISABLE_VALIDATION");
+ OffloadContextVal = Context;
}
// TODO: We can properly reference count here and manage the resources in a more
@@ -229,7 +246,7 @@ Error olGetDeviceInfoImplDetail(ol_device_handle_t Device,
// Find the info if it exists under any of the given names
auto GetInfo = [&](std::vector<std::string> Names) {
- if (Device == HostDevice())
+ if (Device == OffloadContext::get().HostDevice())
return std::string("Host");
if (!Device->Device)
@@ -251,8 +268,9 @@ Error olGetDeviceInfoImplDetail(ol_device_handle_t Device,
case OL_DEVICE_INFO_PLATFORM:
return ReturnValue(Device->Platform);
case OL_DEVICE_INFO_TYPE:
- return Device == HostDevice() ? ReturnValue(OL_DEVICE_TYPE_HOST)
- : ReturnValue(OL_DEVICE_TYPE_GPU);
+ return Device == OffloadContext::get().HostDevice()
+ ? ReturnValue(OL_DEVICE_TYPE_HOST)
+ : ReturnValue(OL_DEVICE_TYPE_GPU);
case OL_DEVICE_INFO_NAME:
return ReturnValue(GetInfo({"Device Name"}).c_str());
case OL_DEVICE_INFO_VENDOR:
@@ -280,7 +298,7 @@ Error olGetDeviceInfoSize_impl(ol_device_handle_t Device,
}
Error olIterateDevices_impl(ol_device_iterate_cb_t Callback, void *UserData) {
- for (auto &Platform : Platforms()) {
+ for (auto &Platform : OffloadContext::get().Platforms) {
for (auto &Device : Platform.Devices) {
if (!Callback(&Device, UserData)) {
break;
@@ -311,16 +329,17 @@ Error olMemAlloc_impl(ol_device_handle_t Device, ol_alloc_type_t Type,
return Alloc.takeError();
*AllocationOut = *Alloc;
- allocInfoMap().insert_or_assign(*Alloc, AllocInfo{Device, Type});
+ OffloadContext::get().AllocInfoMap.insert_or_assign(*Alloc,
+ AllocInfo{Device, Type});
return Error::success();
}
Error olMemFree_impl(void *Address) {
- if (!allocInfoMap().contains(Address))
+ if (!OffloadContext::get().AllocInfoMap.contains(Address))
return createOffloadError(ErrorCode::INVALID_ARGUMENT,
"address is not a known allocation");
- auto AllocInfo = allocInfoMap().at(Address);
+ auto AllocInfo = OffloadContext::get().AllocInfoMap.at(Address);
auto Device = AllocInfo.Device;
auto Type = AllocInfo.Type;
@@ -328,7 +347,7 @@ Error olMemFree_impl(void *Address) {
Device->Device->dataDelete(Address, convertOlToPluginAllocTy(Type)))
return Res;
- allocInfoMap().erase(Address);
+ OffloadContext::get().AllocInfoMap.erase(Address);
return Error::success();
}
@@ -395,7 +414,8 @@ Error olMemcpy_impl(ol_queue_handle_t Queue, void *DstPtr,
ol_device_handle_t DstDevice, const void *SrcPtr,
ol_device_handle_t SrcDevice, size_t Size,
ol_event_handle_t *EventOut) {
- if (DstDevice == HostDevice() && SrcDevice == HostDevice()) {
+ auto Host = OffloadContext::get().HostDevice();
+ if (DstDevice == Host && SrcDevice == Host) {
if (!Queue) {
std::memcpy(DstPtr, SrcPtr, Size);
return Error::success();
@@ -410,11 +430,11 @@ Error olMemcpy_impl(ol_queue_handle_t Queue, void *DstPtr,
// If no queue is given the memcpy will be synchronous
auto QueueImpl = Queue ? Queue->AsyncInfo : nullptr;
- if (DstDevice == HostDevice()) {
+ if (DstDevice == Host) {
if (auto Res =
SrcDevice->Device->dataRetrieve(DstPtr, SrcPtr, Size, QueueImpl))
return Res;
- } else if (SrcDevice == HostDevice()) {
+ } else if (SrcDevice == Host) {
if (auto Res =
DstDevice->Device->dataSubmit(DstPtr, SrcPtr, Size, QueueImpl))
return Res;
diff --git a/offload/liboffload/src/OffloadLib.cpp b/offload/liboffload/src/OffloadLib.cpp
index 8662d3a44124b..0a65815e59698 100644
--- a/offload/liboffload/src/OffloadLib.cpp
+++ b/offload/liboffload/src/OffloadLib.cpp
@@ -30,11 +30,6 @@ ol_code_location_t *¤tCodeLocation() {
return CodeLoc;
}
-OffloadConfig &offloadConfig() {
- static OffloadConfig Config{};
- return Config;
-}
-
namespace llvm {
namespace offload {
// Pull in the declarations for the implementation functions. The actual entry
diff --git a/offload/tools/offload-tblgen/EntryPointGen.cpp b/offload/tools/offload-tblgen/EntryPointGen.cpp
index 85c5c50bf2f20..13aa0d1f63187 100644
--- a/offload/tools/offload-tblgen/EntryPointGen.cpp
+++ b/offload/tools/offload-tblgen/EntryPointGen.cpp
@@ -35,21 +35,30 @@ static void EmitValidationFunc(const FunctionRec &F, raw_ostream &OS) {
}
OS << ") {\n";
- OS << TAB_1 "if (offloadConfig().ValidationEnabled) {\n";
- // Emit validation checks
- for (const auto &Return : F.getReturns()) {
- for (auto &Condition : Return.getConditions()) {
- if (Condition.starts_with("`") && Condition.ends_with("`")) {
- auto ConditionString = Condition.substr(1, Condition.size() - 2);
- OS << formatv(TAB_2 "if ({0}) {{\n", ConditionString);
- OS << formatv(TAB_3 "return createOffloadError(error::ErrorCode::{0}, "
- "\"validation failure: {1}\");\n",
- Return.getUnprefixedValue(), ConditionString);
- OS << TAB_2 "}\n\n";
+ bool HasValidation = llvm::any_of(F.getReturns(), [](auto &R) {
+ return llvm::any_of(R.getConditions(), [](auto &C) {
+ return C.starts_with("`") && C.ends_with("`");
+ });
+ });
+
+ if (HasValidation) {
+ OS << TAB_1 "if (llvm::offload::isValidationEnabled()) {\n";
+ // Emit validation checks
+ for (const auto &Return : F.getReturns()) {
+ for (auto &Condition : Return.getConditions()) {
+ if (Condition.starts_with("`") && Condition.ends_with("`")) {
+ auto ConditionString = Condition.substr(1, Condition.size() - 2);
+ OS << formatv(TAB_2 "if ({0}) {{\n", ConditionString);
+ OS << formatv(TAB_3
+ "return createOffloadError(error::ErrorCode::{0}, "
+ "\"validation failure: {1}\");\n",
+ Return.getUnprefixedValue(), ConditionString);
+ OS << TAB_2 "}\n\n";
+ }
}
}
+ OS << TAB_1 "}\n\n";
}
- OS << TAB_1 "}\n\n";
// Perform actual function call to the implementation
ParamNameList = ParamNameList.substr(0, ParamNameList.size() - 2);
@@ -74,7 +83,7 @@ static void EmitEntryPointFunc(const FunctionRec &F, raw_ostream &OS) {
OS << ") {\n";
// Emit pre-call prints
- OS << TAB_1 "if (offloadConfig().TracingEnabled) {\n";
+ OS << TAB_1 "if (llvm::offload::isTracingEnabled()) {\n";
OS << formatv(TAB_2 "llvm::errs() << \"---> {0}\";\n", F.getName());
OS << TAB_1 "}\n\n";
@@ -85,7 +94,7 @@ static void EmitEntryPointFunc(const FunctionRec &F, raw_ostream &OS) {
PrefixLower, F.getName(), ParamNameList);
// Emit post-call prints
- OS << TAB_1 "if (offloadConfig().TracingEnabled) {\n";
+ OS << TAB_1 "if (llvm::offload::isTracingEnabled()) {\n";
if (F.getParams().size() > 0) {
OS << formatv(TAB_2 "{0} Params = {{", F.getParamStructName());
for (const auto &Param : F.getParams()) {
|
This seems to have introduced a test regression for me. 9fd22cb is passing tests, but with 53336ad and later I'm getting test failures. The result with main as of 77941eb is:
|
Likely the additions changed the text slightly and it wasn't updated? @RossBrunton How quickly can you fix it? |
I'm pretty sure I updated this, is this reproducible with a clean build? @jhuber6 I can have a look on Monday. If you want to revert for now, feel free. |
Yep, completely clean standalone build. |
This was broken as part of llvm#144494 , and just needs an update to the check lines.
#145292 Should be fixed by this, sorry about that! |
This was broken as part of #144494 , and just needs an update to the check lines.
This was broken as part of llvm#144494 , and just needs an update to the check lines.
This was broken as part of llvm#144494 , and just needs an update to the check lines.
Rather than having a number of static local variables, we now use
a single
OffloadContext
struct to store global state. This isinitialised by
olInit
, but is never deleted (de-initialization ofOffload isn't yet implemented).
The error reporting mechanism has not been moved to the struct, since
that's going to cause issues with teardown (error messages must outlive
liboffload).