Skip to content

Commit 5812d0b

Browse files
authored
[Offload] Make only a single thread handle the RPC server thread (#126067)
Summary: This patch just changes the interface to make starting the thread multiple times permissible, since the thread will only actually be started the first time. Note that this does not refcount it or anything, so it's up to the user to make sure that they don't shut down the thread before everyone is done using it. That is the case today because the shutDown portion is run by a single thread in the destructor phase. Another question is whether we should make this thread truly global state, because currently it is private to each plugin instance, so if you have an AMD and an NVIDIA image there will be two threads; the same applies if you have those inside of a shared library.
1 parent 11c3f52 commit 5812d0b

File tree

3 files changed

+9
-14
lines changed

3 files changed

+9
-14
lines changed

offload/plugins-nextgen/common/include/RPC.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ struct RPCServerTy {
8080
std::thread Worker;
8181

8282
/// A boolean indicating whether or not the worker thread should continue.
83-
std::atomic<bool> Running;
83+
std::atomic<uint32_t> Running;
8484

8585
/// The number of currently executing kernels across all devices that need
8686
/// the server thread to be running.

offload/plugins-nextgen/common/src/PluginInterface.cpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1058,9 +1058,8 @@ Error GenericDeviceTy::setupRPCServer(GenericPluginTy &Plugin,
10581058
if (auto Err = Server.initDevice(*this, Plugin.getGlobalHandler(), Image))
10591059
return Err;
10601060

1061-
if (!Server.Thread->Running.load(std::memory_order_acquire))
1062-
if (auto Err = Server.startThread())
1063-
return Err;
1061+
if (auto Err = Server.startThread())
1062+
return Err;
10641063

10651064
RPCServer = &Server;
10661065
DP("Running an RPC server on device %d\n", getDeviceId());
@@ -1635,12 +1634,11 @@ Error GenericPluginTy::deinit() {
16351634
if (GlobalHandler)
16361635
delete GlobalHandler;
16371636

1638-
if (RPCServer && RPCServer->Thread->Running.load(std::memory_order_acquire))
1637+
if (RPCServer) {
16391638
if (Error Err = RPCServer->shutDown())
16401639
return Err;
1641-
1642-
if (RPCServer)
16431640
delete RPCServer;
1641+
}
16441642

16451643
if (RecordReplay)
16461644
delete RecordReplay;

offload/plugins-nextgen/common/src/RPC.cpp

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -99,18 +99,15 @@ static rpc::Status runServer(plugin::GenericDeviceTy &Device, void *Buffer) {
9999
}
100100

101101
void RPCServerTy::ServerThread::startThread() {
102-
assert(!Running.load(std::memory_order_relaxed) &&
103-
"Attempting to start thread that is already running");
104-
Running.store(true, std::memory_order_release);
105-
Worker = std::thread([this]() { run(); });
102+
if (!Running.fetch_or(true, std::memory_order_acquire))
103+
Worker = std::thread([this]() { run(); });
106104
}
107105

108106
void RPCServerTy::ServerThread::shutDown() {
109-
assert(Running.load(std::memory_order_relaxed) &&
110-
"Attempting to shut down a thread that is not running");
107+
if (!Running.fetch_and(false, std::memory_order_release))
108+
return;
111109
{
112110
std::lock_guard<decltype(Mutex)> Lock(Mutex);
113-
Running.store(false, std::memory_order_release);
114111
CV.notify_all();
115112
}
116113
if (Worker.joinable())

0 commit comments

Comments
 (0)