Skip to content

[Offload] Move (most) global state to an OffloadContext struct #144494

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions offload/liboffload/include/OffloadImpl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/Error.h"

struct OffloadConfig {
bool TracingEnabled = false;
bool ValidationEnabled = true;
};

OffloadConfig &offloadConfig();
namespace llvm {
namespace offload {
bool isTracingEnabled();
bool isValidationEnabled();
} // namespace offload
} // namespace llvm

// Use the StringSet container to efficiently deduplicate repeated error
// strings (e.g. if the same error is hit constantly in a long running program)
Expand Down
82 changes: 51 additions & 31 deletions offload/liboffload/src/OffloadImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,22 +93,36 @@ struct AllocInfo {
ol_alloc_type_t Type;
};

using AllocInfoMapT = DenseMap<void *, AllocInfo>;
AllocInfoMapT &allocInfoMap() {
static AllocInfoMapT AllocInfoMap{};
return AllocInfoMap;
}
// Global shared state for liboffload
struct OffloadContext;
static OffloadContext *OffloadContextVal;
struct OffloadContext {
OffloadContext(OffloadContext &) = delete;
OffloadContext(OffloadContext &&) = delete;
Comment on lines +100 to +101
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You probably want to delete everything except the default constructor..

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added "deletes" for operator=, anything else you want me to add?

In the future, olShutDown will need to destruct this object, so I don't think ~OffloadContext could be deleted.

OffloadContext &operator=(OffloadContext &) = delete;
OffloadContext &operator=(OffloadContext &&) = delete;

bool TracingEnabled = false;
bool ValidationEnabled = true;
DenseMap<void *, AllocInfo> AllocInfoMap{};
SmallVector<ol_platform_impl_t, 4> Platforms{};

ol_device_handle_t HostDevice() {
// The host platform is always inserted last
return &Platforms.back().Devices[0];
}

using PlatformVecT = SmallVector<ol_platform_impl_t, 4>;
PlatformVecT &Platforms() {
static PlatformVecT Platforms;
return Platforms;
}
static OffloadContext &get() {
assert(OffloadContextVal);
return *OffloadContextVal;
}
};

ol_device_handle_t HostDevice() {
// The host platform is always inserted last
return &Platforms().back().Devices[0];
// If the context is uninited, then we assume tracing is disabled
bool isTracingEnabled() {
return OffloadContextVal && OffloadContext::get().TracingEnabled;
}
bool isValidationEnabled() { return OffloadContext::get().ValidationEnabled; }

template <typename HandleT> Error olDestroy(HandleT Handle) {
delete Handle;
Expand All @@ -130,18 +144,20 @@ constexpr ol_platform_backend_t pluginNameToBackend(StringRef Name) {
#include "Shared/Targets.def"

void initPlugins() {
auto *Context = new OffloadContext{};

// Attempt to create an instance of each supported plugin.
#define PLUGIN_TARGET(Name) \
do { \
Platforms().emplace_back(ol_platform_impl_t{ \
Context->Platforms.emplace_back(ol_platform_impl_t{ \
std::unique_ptr<GenericPluginTy>(createPlugin_##Name()), \
{}, \
pluginNameToBackend(#Name)}); \
} while (false);
#include "Shared/Targets.def"

// Preemptively initialize all devices in the plugin
for (auto &Platform : Platforms()) {
for (auto &Platform : Context->Platforms) {
// Do not use the host plugin - it isn't supported.
if (Platform.BackendType == OL_PLATFORM_BACKEND_UNKNOWN)
continue;
Expand All @@ -157,15 +173,16 @@ void initPlugins() {
}

// Add the special host device
auto &HostPlatform = Platforms().emplace_back(
auto &HostPlatform = Context->Platforms.emplace_back(
ol_platform_impl_t{nullptr,
{ol_device_impl_t{-1, nullptr, nullptr}},
OL_PLATFORM_BACKEND_HOST});
HostDevice()->Platform = &HostPlatform;
Context->HostDevice()->Platform = &HostPlatform;

Context->TracingEnabled = std::getenv("OFFLOAD_TRACE");
Context->ValidationEnabled = !std::getenv("OFFLOAD_DISABLE_VALIDATION");

offloadConfig().TracingEnabled = std::getenv("OFFLOAD_TRACE");
offloadConfig().ValidationEnabled =
!std::getenv("OFFLOAD_DISABLE_VALIDATION");
OffloadContextVal = Context;
}

// TODO: We can properly reference count here and manage the resources in a more
Expand Down Expand Up @@ -229,7 +246,7 @@ Error olGetDeviceInfoImplDetail(ol_device_handle_t Device,

// Find the info if it exists under any of the given names
auto GetInfo = [&](std::vector<std::string> Names) {
if (Device == HostDevice())
if (Device == OffloadContext::get().HostDevice())
return std::string("Host");

if (!Device->Device)
Expand All @@ -251,8 +268,9 @@ Error olGetDeviceInfoImplDetail(ol_device_handle_t Device,
case OL_DEVICE_INFO_PLATFORM:
return ReturnValue(Device->Platform);
case OL_DEVICE_INFO_TYPE:
return Device == HostDevice() ? ReturnValue(OL_DEVICE_TYPE_HOST)
: ReturnValue(OL_DEVICE_TYPE_GPU);
return Device == OffloadContext::get().HostDevice()
? ReturnValue(OL_DEVICE_TYPE_HOST)
: ReturnValue(OL_DEVICE_TYPE_GPU);
case OL_DEVICE_INFO_NAME:
return ReturnValue(GetInfo({"Device Name"}).c_str());
case OL_DEVICE_INFO_VENDOR:
Expand Down Expand Up @@ -280,7 +298,7 @@ Error olGetDeviceInfoSize_impl(ol_device_handle_t Device,
}

Error olIterateDevices_impl(ol_device_iterate_cb_t Callback, void *UserData) {
for (auto &Platform : Platforms()) {
for (auto &Platform : OffloadContext::get().Platforms) {
for (auto &Device : Platform.Devices) {
if (!Callback(&Device, UserData)) {
break;
Expand Down Expand Up @@ -311,24 +329,25 @@ Error olMemAlloc_impl(ol_device_handle_t Device, ol_alloc_type_t Type,
return Alloc.takeError();

*AllocationOut = *Alloc;
allocInfoMap().insert_or_assign(*Alloc, AllocInfo{Device, Type});
OffloadContext::get().AllocInfoMap.insert_or_assign(*Alloc,
AllocInfo{Device, Type});
return Error::success();
}

Error olMemFree_impl(void *Address) {
if (!allocInfoMap().contains(Address))
if (!OffloadContext::get().AllocInfoMap.contains(Address))
return createOffloadError(ErrorCode::INVALID_ARGUMENT,
"address is not a known allocation");

auto AllocInfo = allocInfoMap().at(Address);
auto AllocInfo = OffloadContext::get().AllocInfoMap.at(Address);
auto Device = AllocInfo.Device;
auto Type = AllocInfo.Type;

if (auto Res =
Device->Device->dataDelete(Address, convertOlToPluginAllocTy(Type)))
return Res;

allocInfoMap().erase(Address);
OffloadContext::get().AllocInfoMap.erase(Address);

return Error::success();
}
Expand Down Expand Up @@ -395,7 +414,8 @@ Error olMemcpy_impl(ol_queue_handle_t Queue, void *DstPtr,
ol_device_handle_t DstDevice, const void *SrcPtr,
ol_device_handle_t SrcDevice, size_t Size,
ol_event_handle_t *EventOut) {
if (DstDevice == HostDevice() && SrcDevice == HostDevice()) {
auto Host = OffloadContext::get().HostDevice();
if (DstDevice == Host && SrcDevice == Host) {
if (!Queue) {
std::memcpy(DstPtr, SrcPtr, Size);
return Error::success();
Expand All @@ -410,11 +430,11 @@ Error olMemcpy_impl(ol_queue_handle_t Queue, void *DstPtr,
// If no queue is given the memcpy will be synchronous
auto QueueImpl = Queue ? Queue->AsyncInfo : nullptr;

if (DstDevice == HostDevice()) {
if (DstDevice == Host) {
if (auto Res =
SrcDevice->Device->dataRetrieve(DstPtr, SrcPtr, Size, QueueImpl))
return Res;
} else if (SrcDevice == HostDevice()) {
} else if (SrcDevice == Host) {
if (auto Res =
DstDevice->Device->dataSubmit(DstPtr, SrcPtr, Size, QueueImpl))
return Res;
Expand Down
5 changes: 0 additions & 5 deletions offload/liboffload/src/OffloadLib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,6 @@ ol_code_location_t *&currentCodeLocation() {
return CodeLoc;
}

OffloadConfig &offloadConfig() {
static OffloadConfig Config{};
return Config;
}

namespace llvm {
namespace offload {
// Pull in the declarations for the implementation functions. The actual entry
Expand Down
37 changes: 23 additions & 14 deletions offload/tools/offload-tblgen/EntryPointGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,21 +35,30 @@ static void EmitValidationFunc(const FunctionRec &F, raw_ostream &OS) {
}
OS << ") {\n";

OS << TAB_1 "if (offloadConfig().ValidationEnabled) {\n";
// Emit validation checks
for (const auto &Return : F.getReturns()) {
for (auto &Condition : Return.getConditions()) {
if (Condition.starts_with("`") && Condition.ends_with("`")) {
auto ConditionString = Condition.substr(1, Condition.size() - 2);
OS << formatv(TAB_2 "if ({0}) {{\n", ConditionString);
OS << formatv(TAB_3 "return createOffloadError(error::ErrorCode::{0}, "
"\"validation failure: {1}\");\n",
Return.getUnprefixedValue(), ConditionString);
OS << TAB_2 "}\n\n";
bool HasValidation = llvm::any_of(F.getReturns(), [](auto &R) {
return llvm::any_of(R.getConditions(), [](auto &C) {
return C.starts_with("`") && C.ends_with("`");
});
});

if (HasValidation) {
OS << TAB_1 "if (llvm::offload::isValidationEnabled()) {\n";
// Emit validation checks
for (const auto &Return : F.getReturns()) {
for (auto &Condition : Return.getConditions()) {
if (Condition.starts_with("`") && Condition.ends_with("`")) {
auto ConditionString = Condition.substr(1, Condition.size() - 2);
OS << formatv(TAB_2 "if ({0}) {{\n", ConditionString);
OS << formatv(TAB_3
"return createOffloadError(error::ErrorCode::{0}, "
"\"validation failure: {1}\");\n",
Return.getUnprefixedValue(), ConditionString);
OS << TAB_2 "}\n\n";
}
}
}
OS << TAB_1 "}\n\n";
}
OS << TAB_1 "}\n\n";

// Perform actual function call to the implementation
ParamNameList = ParamNameList.substr(0, ParamNameList.size() - 2);
Expand All @@ -74,7 +83,7 @@ static void EmitEntryPointFunc(const FunctionRec &F, raw_ostream &OS) {
OS << ") {\n";

// Emit pre-call prints
OS << TAB_1 "if (offloadConfig().TracingEnabled) {\n";
OS << TAB_1 "if (llvm::offload::isTracingEnabled()) {\n";
OS << formatv(TAB_2 "llvm::errs() << \"---> {0}\";\n", F.getName());
OS << TAB_1 "}\n\n";

Expand All @@ -85,7 +94,7 @@ static void EmitEntryPointFunc(const FunctionRec &F, raw_ostream &OS) {
PrefixLower, F.getName(), ParamNameList);

// Emit post-call prints
OS << TAB_1 "if (offloadConfig().TracingEnabled) {\n";
OS << TAB_1 "if (llvm::offload::isTracingEnabled()) {\n";
if (F.getParams().size() > 0) {
OS << formatv(TAB_2 "{0} Params = {{", F.getParamStructName());
for (const auto &Param : F.getParams()) {
Expand Down
Loading