Skip to content

[Libomptarget] Remove global ctor and use reference counting #80499

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions openmp/libomptarget/include/PluginManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,12 @@ struct PluginManager {
ProtectedObj<DeviceContainerTy> Devices;
};

/// Initialize the plugin manager and OpenMP runtime.
void initRuntime();

/// Deinitialize the plugin manager and delete it once the last reference is dropped.
void deinitRuntime();

extern PluginManager *PM;

#endif // OMPTARGET_PLUGIN_MANAGER_H
6 changes: 6 additions & 0 deletions openmp/libomptarget/include/omptarget.h
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,12 @@ void *llvm_omp_target_dynamic_shared_alloc();
/// add the clauses of the requires directives in a given file
void __tgt_register_requires(int64_t Flags);

/// Initializes the runtime library.
void __tgt_rtl_init();

/// Deinitializes the runtime library.
void __tgt_rtl_deinit();

/// adds a target shared library to the target execution image
void __tgt_register_lib(__tgt_bin_desc *Desc);

Expand Down
38 changes: 26 additions & 12 deletions openmp/libomptarget/src/OffloadRTL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,25 +20,39 @@
extern void llvm::omp::target::ompt::connectLibrary();
#endif

__attribute__((constructor(101))) void init() {
static std::mutex PluginMtx;
static uint32_t RefCount = 0;

void initRuntime() {
std::scoped_lock<decltype(PluginMtx)> Lock(PluginMtx);
Profiler::get();
TIMESCOPE();

DP("Init offload library!\n");

PM = new PluginManager();
if (PM == nullptr)
PM = new PluginManager();

RefCount++;
if (RefCount == 1) {
DP("Init offload library!\n");
#ifdef OMPT_SUPPORT
// Initialize OMPT first
llvm::omp::target::ompt::connectLibrary();
// Initialize OMPT first
llvm::omp::target::ompt::connectLibrary();
#endif

PM->init();

PM->registerDelayedLibraries();
PM->init();
PM->registerDelayedLibraries();
}
}

__attribute__((destructor(101))) void deinit() {
DP("Deinit offload library!\n");
delete PM;
/// Release one reference to the offload runtime.
///
/// The call is serialized against other initRuntime()/deinitRuntime() calls by
/// PluginMtx. When the reference being released is the last one
/// (RefCount == 1) the plugin manager is deleted and PM is reset to nullptr,
/// which is what lets a later initRuntime() allocate a fresh manager.
void deinitRuntime() {
  std::scoped_lock<decltype(PluginMtx)> Lock(PluginMtx);
  assert(PM && "Runtime not initialized");

  // RefCount is unsigned. If assertions are compiled out, an unmatched call
  // would otherwise wrap the counter past zero, and initRuntime() would then
  // never observe RefCount == 1 again, so the runtime could not be brought
  // back up. Treat the unmatched call as a no-op instead.
  if (RefCount == 0)
    return;

  if (RefCount == 1) {
    DP("Deinit offload library!\n");
    delete PM;
    PM = nullptr;
  }

  --RefCount;
}
2 changes: 1 addition & 1 deletion openmp/libomptarget/src/PluginManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
using namespace llvm;
using namespace llvm::sys;

PluginManager *PM;
PluginManager *PM = nullptr;

// List of all plugins that can support offloading.
static const char *RTLNames[] = {ENABLED_OFFLOAD_PLUGINS};
Expand Down
2 changes: 2 additions & 0 deletions openmp/libomptarget/src/exports
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
VERS1.0 {
global:
__tgt_rtl_init;
__tgt_rtl_deinit;
__tgt_register_requires;
__tgt_register_lib;
__tgt_unregister_lib;
Expand Down
20 changes: 18 additions & 2 deletions openmp/libomptarget/src/interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,13 @@ EXTERN void __tgt_register_requires(int64_t Flags) {
__PRETTY_FUNCTION__);
}

/// Exported entry point for initializing the runtime library; it simply
/// forwards to initRuntime().
EXTERN void __tgt_rtl_init() { initRuntime(); }
/// Exported entry point for deinitializing the runtime library; it simply
/// forwards to deinitRuntime().
EXTERN void __tgt_rtl_deinit() { deinitRuntime(); }

////////////////////////////////////////////////////////////////////////////////
/// adds a target shared library to the target execution image
EXTERN void __tgt_register_lib(__tgt_bin_desc *Desc) {
initRuntime();
if (PM->delayRegisterLib(Desc))
return;

Expand All @@ -49,12 +53,17 @@ EXTERN void __tgt_register_lib(__tgt_bin_desc *Desc) {

////////////////////////////////////////////////////////////////////////////////
/// Initialize all available devices without registering any image
EXTERN void __tgt_init_all_rtls() { PM->initAllPlugins(); }
EXTERN void __tgt_init_all_rtls() {
// This entry point does not call initRuntime() itself, so PM must already
// have been created by an earlier call that does.
assert(PM && "Runtime not initialized");
PM->initAllPlugins();
}

////////////////////////////////////////////////////////////////////////////////
/// unloads a target shared library
EXTERN void __tgt_unregister_lib(__tgt_bin_desc *Desc) {
  // PM is dereferenced below before deinitRuntime() gets a chance to check
  // it, so guard it here the same way the other entry points in this file do.
  assert(PM && "Runtime not initialized");
  PM->unregisterLib(Desc);

  // Releases the runtime reference that __tgt_register_lib acquired through
  // its initRuntime() call.
  deinitRuntime();
}

template <typename TargetAsyncInfoTy>
Expand All @@ -64,6 +73,7 @@ targetData(ident_t *Loc, int64_t DeviceId, int32_t ArgNum, void **ArgsBase,
map_var_info_t *ArgNames, void **ArgMappers,
TargetDataFuncPtrTy TargetDataFunction, const char *RegionTypeMsg,
const char *RegionName) {
assert(PM && "Runtime not initialized");
static_assert(std::is_convertible_v<TargetAsyncInfoTy, AsyncInfoTy>,
"TargetAsyncInfoTy must be convertible to AsyncInfoTy.");

Expand Down Expand Up @@ -239,6 +249,7 @@ template <typename TargetAsyncInfoTy>
static inline int targetKernel(ident_t *Loc, int64_t DeviceId, int32_t NumTeams,
int32_t ThreadLimit, void *HostPtr,
KernelArgsTy *KernelArgs) {
assert(PM && "Runtime not initialized");
static_assert(std::is_convertible_v<TargetAsyncInfoTy, AsyncInfoTy>,
"Target AsyncInfoTy must be convertible to AsyncInfoTy.");
DP("Entering target region for device %" PRId64 " with entry point " DPxMOD
Expand Down Expand Up @@ -345,6 +356,7 @@ EXTERN int __tgt_activate_record_replay(int64_t DeviceId, uint64_t MemorySize,
void *VAddr, bool IsRecord,
bool SaveOutput,
uint64_t &ReqPtrArgOffset) {
assert(PM && "Runtime not initialized");
OMPT_IF_BUILT(ReturnAddressSetterRAII RA(__builtin_return_address(0)));
auto DeviceOrErr = PM->getDevice(DeviceId);
if (!DeviceOrErr)
Expand Down Expand Up @@ -380,7 +392,7 @@ EXTERN int __tgt_target_kernel_replay(ident_t *Loc, int64_t DeviceId,
ptrdiff_t *TgtOffsets, int32_t NumArgs,
int32_t NumTeams, int32_t ThreadLimit,
uint64_t LoopTripCount) {

assert(PM && "Runtime not initialized");
OMPT_IF_BUILT(ReturnAddressSetterRAII RA(__builtin_return_address(0)));
if (checkDeviceAndCtors(DeviceId, Loc)) {
DP("Not offloading to device %" PRId64 "\n", DeviceId);
Expand Down Expand Up @@ -431,6 +443,7 @@ EXTERN void __tgt_push_mapper_component(void *RtMapperHandle, void *Base,
}

EXTERN void __tgt_set_info_flag(uint32_t NewInfoLevel) {
assert(PM && "Runtime not initialized");
std::atomic<uint32_t> &InfoLevel = getInfoLevelInternal();
InfoLevel.store(NewInfoLevel);
for (auto &R : PM->pluginAdaptors()) {
Expand All @@ -440,6 +453,7 @@ EXTERN void __tgt_set_info_flag(uint32_t NewInfoLevel) {
}

EXTERN int __tgt_print_device_info(int64_t DeviceId) {
assert(PM && "Runtime not initialized");
auto DeviceOrErr = PM->getDevice(DeviceId);
if (!DeviceOrErr)
FATAL_MESSAGE(DeviceId, "%s", toString(DeviceOrErr.takeError()).c_str());
Expand All @@ -448,7 +462,9 @@ EXTERN int __tgt_print_device_info(int64_t DeviceId) {
}

EXTERN void __tgt_target_nowait_query(void **AsyncHandle) {
assert(PM && "Runtime not initialized");
OMPT_IF_BUILT(ReturnAddressSetterRAII RA(__builtin_return_address(0)));

if (!AsyncHandle || !*AsyncHandle) {
FATAL_MESSAGE0(
1, "Receive an invalid async handle from the current OpenMP task. Is "
Expand Down
30 changes: 30 additions & 0 deletions openmp/libomptarget/test/offloading/runtime_init.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// RUN: %libomptarget-compile-generic
// RUN: env LIBOMPTARGET_DEBUG=1 %libomptarget-run-generic 2>&1 \
// RUN: %fcheck-generic

// REQUIRES: libomptarget-debug

#include <omp.h>
#include <stdio.h>

extern void __tgt_rtl_init(void);
extern void __tgt_rtl_deinit(void);

// Sanity checks to make sure that this works and is thread safe.
int main() {
// CHECK: Init offload library!
// CHECK: Deinit offload library!
// Take an outer reference before the parallel region, so every per-thread
// init/deinit pair below runs while the runtime is already referenced.
__tgt_rtl_init();
#pragma omp parallel num_threads(8)
{
// Each of the 8 requested threads acquires and releases its own reference;
// these calls may run concurrently with one another.
__tgt_rtl_init();
__tgt_rtl_deinit();
}
// Release the outer reference taken before the parallel region.
__tgt_rtl_deinit();

// One more init/deinit pair after the outer reference has been released.
__tgt_rtl_init();
__tgt_rtl_deinit();

// CHECK: PASS
printf("PASS\n");
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test is not perfect but I'm unsure how to improve it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I could add a test that dlopens a bunch of different libraries in parallel or something. Realistically this needs to be a unit test, but we're missing a bit of infrastructure right now. Figured that would come later.