Skip to content

Commit 41facc9

Browse files
committed
[Offload] Move RPC server handling to a dedicated thread
Summary: Handling the RPC server requires running through the list of jobs that the device has requested to be done. Currently this is handled by the thread that waits for the kernel to finish. However, this is not sound on NVIDIA architectures and only works for async launches in the OpenMP model that uses helper threads. Moving the handling to a dedicated thread fixes this, but we also don't want that thread doing work unnecessarily. For this reason we track the execution of kernels and cause the thread to sleep via a condition variable (usually backed by some kind of futex or other intelligent sleeping mechanism) so that the thread will be idle while no kernels are running. Use cuLaunchHostFunc. Only create thread if used.
1 parent a6ef0de commit 41facc9

File tree

8 files changed

+281
-88
lines changed

8 files changed

+281
-88
lines changed

offload/plugins-nextgen/amdgpu/src/rtl.cpp

Lines changed: 32 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -621,9 +621,9 @@ struct AMDGPUSignalTy {
621621
}
622622

623623
/// Wait until the signal gets a zero value.
624-
Error wait(const uint64_t ActiveTimeout = 0, RPCServerTy *RPCServer = nullptr,
624+
Error wait(const uint64_t ActiveTimeout = 0,
625625
GenericDeviceTy *Device = nullptr) const {
626-
if (ActiveTimeout && !RPCServer) {
626+
if (ActiveTimeout) {
627627
hsa_signal_value_t Got = 1;
628628
Got = hsa_signal_wait_scacquire(HSASignal, HSA_SIGNAL_CONDITION_EQ, 0,
629629
ActiveTimeout, HSA_WAIT_STATE_ACTIVE);
@@ -632,14 +632,11 @@ struct AMDGPUSignalTy {
632632
}
633633

634634
// If there is an RPC device attached to this stream we run it as a server.
635-
uint64_t Timeout = RPCServer ? 8192 : UINT64_MAX;
636-
auto WaitState = RPCServer ? HSA_WAIT_STATE_ACTIVE : HSA_WAIT_STATE_BLOCKED;
635+
uint64_t Timeout = UINT64_MAX;
636+
auto WaitState = HSA_WAIT_STATE_BLOCKED;
637637
while (hsa_signal_wait_scacquire(HSASignal, HSA_SIGNAL_CONDITION_EQ, 0,
638-
Timeout, WaitState) != 0) {
639-
if (RPCServer && Device)
640-
if (auto Err = RPCServer->runServer(*Device))
641-
return Err;
642-
}
638+
Timeout, WaitState) != 0)
639+
;
643640
return Plugin::success();
644641
}
645642

@@ -1048,11 +1045,6 @@ struct AMDGPUStreamTy {
10481045
/// operation that was already finalized in a previous stream synchronize.
10491046
uint32_t SyncCycle;
10501047

1051-
/// A pointer associated with an RPC server running on the given device. If
1052-
/// RPC is not being used this will be a null pointer. Otherwise, this
1053-
/// indicates that an RPC server is expected to be run on this stream.
1054-
RPCServerTy *RPCServer;
1055-
10561048
/// Mutex to protect stream's management.
10571049
mutable std::mutex Mutex;
10581050

@@ -1232,9 +1224,6 @@ struct AMDGPUStreamTy {
12321224
/// Deinitialize the stream's signals.
12331225
Error deinit() { return Plugin::success(); }
12341226

1235-
/// Attach an RPC server to this stream.
1236-
void setRPCServer(RPCServerTy *Server) { RPCServer = Server; }
1237-
12381227
/// Push an asynchronous kernel to the stream. The kernel arguments must be
12391228
/// placed in a special allocation for kernel args and must keep alive until
12401229
/// the kernel finalizes. Once the kernel is finished, the stream will release
@@ -1262,10 +1251,30 @@ struct AMDGPUStreamTy {
12621251
if (auto Err = Slots[Curr].schedReleaseBuffer(KernelArgs, MemoryManager))
12631252
return Err;
12641253

1254+
// If we are running an RPC server we want to wake up the server thread
1255+
// whenever there is a kernel running and let it sleep otherwise.
1256+
if (Device.getRPCServer())
1257+
Device.Plugin.getRPCServer().Thread->notify();
1258+
12651259
// Push the kernel with the output signal and an input signal (optional)
1266-
return Queue->pushKernelLaunch(Kernel, KernelArgs, NumThreads, NumBlocks,
1267-
GroupSize, StackSize, OutputSignal,
1268-
InputSignal);
1260+
if (auto Err = Queue->pushKernelLaunch(Kernel, KernelArgs, NumThreads,
1261+
NumBlocks, GroupSize, StackSize,
1262+
OutputSignal, InputSignal))
1263+
return Err;
1264+
1265+
// Register a callback to indicate when the kernel is complete.
1266+
if (Device.getRPCServer()) {
1267+
if (auto Err = Slots[Curr].schedCallback(
1268+
[](void *Data) -> llvm::Error {
1269+
GenericPluginTy &Plugin =
1270+
*reinterpret_cast<GenericPluginTy *>(Data);
1271+
Plugin.getRPCServer().Thread->finish();
1272+
return Error::success();
1273+
},
1274+
&Device.Plugin))
1275+
return Err;
1276+
}
1277+
return Plugin::success();
12691278
}
12701279

12711280
/// Push an asynchronous memory copy between pinned memory buffers.
@@ -1475,8 +1484,8 @@ struct AMDGPUStreamTy {
14751484
return Plugin::success();
14761485

14771486
// Wait until all previous operations on the stream have completed.
1478-
if (auto Err = Slots[last()].Signal->wait(StreamBusyWaitMicroseconds,
1479-
RPCServer, &Device))
1487+
if (auto Err =
1488+
Slots[last()].Signal->wait(StreamBusyWaitMicroseconds, &Device))
14801489
return Err;
14811490

14821491
// Reset the stream and perform all pending post actions.
@@ -3025,7 +3034,7 @@ AMDGPUStreamTy::AMDGPUStreamTy(AMDGPUDeviceTy &Device)
30253034
: Agent(Device.getAgent()), Queue(nullptr),
30263035
SignalManager(Device.getSignalManager()), Device(Device),
30273036
// Initialize the std::deque with some empty positions.
3028-
Slots(32), NextSlot(0), SyncCycle(0), RPCServer(nullptr),
3037+
Slots(32), NextSlot(0), SyncCycle(0),
30293038
StreamBusyWaitMicroseconds(Device.getStreamBusyWaitMicroseconds()),
30303039
UseMultipleSdmaEngines(Device.useMultipleSdmaEngines()) {}
30313040

@@ -3378,10 +3387,6 @@ Error AMDGPUKernelTy::launchImpl(GenericDeviceTy &GenericDevice,
33783387
if (auto Err = AMDGPUDevice.getStream(AsyncInfoWrapper, Stream))
33793388
return Err;
33803389

3381-
// If this kernel requires an RPC server we attach its pointer to the stream.
3382-
if (GenericDevice.getRPCServer())
3383-
Stream->setRPCServer(GenericDevice.getRPCServer());
3384-
33853390
// Only COV5 implicitargs needs to be set. COV4 implicitargs are not used.
33863391
if (ImplArgs &&
33873392
getImplicitArgsSize() == sizeof(hsa_utils::AMDGPUImplicitArgsTy)) {

offload/plugins-nextgen/common/include/RPC.h

Lines changed: 75 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,11 @@
1919
#include "llvm/ADT/DenseMap.h"
2020
#include "llvm/Support/Error.h"
2121

22+
#include <atomic>
23+
#include <condition_variable>
2224
#include <cstdint>
25+
#include <mutex>
26+
#include <thread>
2327

2428
namespace llvm::omp::target {
2529
namespace plugin {
@@ -37,6 +41,12 @@ struct RPCServerTy {
3741
/// Initializes the handles to the number of devices we may need to service.
3842
RPCServerTy(plugin::GenericPluginTy &Plugin);
3943

44+
/// Deinitialize the associated memory and resources.
45+
llvm::Error shutDown();
46+
47+
/// Initialize the worker thread.
48+
llvm::Error startThread();
49+
4050
/// Check if this device image is using an RPC server. This checks for the
4151
/// presence of an externally visible symbol in the device image that will
4252
/// be present whenever RPC code is called.
@@ -51,17 +61,77 @@ struct RPCServerTy {
5161
plugin::GenericGlobalHandlerTy &Handler,
5262
plugin::DeviceImageTy &Image);
5363

54-
/// Runs the RPC server associated with the \p Device until the pending work
55-
/// is cleared.
56-
llvm::Error runServer(plugin::GenericDeviceTy &Device);
57-
5864
/// Deinitialize the RPC server for the given device. This will free the
5965
/// memory associated with the k
6066
llvm::Error deinitDevice(plugin::GenericDeviceTy &Device);
6167

6268
private:
6369
/// Array from this device's identifier to its attached devices.
64-
llvm::SmallVector<void *> Buffers;
70+
std::unique_ptr<void *[]> Buffers;
71+
72+
/// Array of associated devices. These must be alive as long as the server is.
73+
std::unique_ptr<plugin::GenericDeviceTy *[]> Devices;
74+
75+
/// A helper class for running the user thread that handles the RPC interface.
76+
/// Because we only need to check the RPC server while any kernels are
77+
/// working, we track submission / completion events to allow the thread to
78+
/// sleep when it is not needed.
79+
struct ServerThread {
80+
std::thread Worker;
81+
82+
/// A boolean indicating whether or not the worker thread should continue.
83+
std::atomic<bool> Running;
84+
85+
/// The number of currently executing kernels across all devices that need
86+
/// the server thread to be running.
87+
std::atomic<uint32_t> NumUsers;
88+
89+
/// The condition variable used to suspend the thread if no work is needed.
90+
std::condition_variable CV;
91+
std::mutex Mutex;
92+
93+
/// A reference to all the RPC interfaces that the server is handling.
94+
llvm::ArrayRef<void *> Buffers;
95+
96+
/// A reference to the associated generic device for the buffer.
97+
llvm::ArrayRef<plugin::GenericDeviceTy *> Devices;
98+
99+
/// Initialize the worker thread to run in the background.
100+
ServerThread(void *Buffers[], plugin::GenericDeviceTy *Devices[],
101+
size_t Length)
102+
: Running(true), NumUsers(0), CV(), Mutex(), Buffers(Buffers, Length),
103+
Devices(Devices, Length) {}
104+
105+
~ServerThread() { assert(!Running && "Thread not shut down explicitly\n"); }
106+
107+
/// Notify the worker thread that there is a user that needs it.
108+
void notify() {
109+
std::lock_guard<decltype(Mutex)> Lock(Mutex);
110+
NumUsers.fetch_add(1, std::memory_order_relaxed);
111+
CV.notify_all();
112+
}
113+
114+
/// Indicate that one of the dependent users has finished.
115+
void finish() {
116+
[[maybe_unused]] uint32_t Old =
117+
NumUsers.fetch_sub(1, std::memory_order_relaxed);
118+
assert(Old > 0 && "Attempt to signal finish with no pending work");
119+
}
120+
121+
/// Destroy the worker thread and wait.
122+
void shutDown();
123+
124+
/// Initialize the worker thread.
125+
void startThread();
126+
127+
/// Run the server thread to continuously check the RPC interface for work
128+
/// to be done for the device.
129+
void run();
130+
};
131+
132+
public:
133+
/// Pointer to the server thread instance.
134+
std::unique_ptr<ServerThread> Thread;
65135
};
66136

67137
} // namespace llvm::omp::target

offload/plugins-nextgen/common/src/PluginInterface.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1051,6 +1051,9 @@ Error GenericDeviceTy::setupRPCServer(GenericPluginTy &Plugin,
10511051
if (auto Err = Server.initDevice(*this, Plugin.getGlobalHandler(), Image))
10521052
return Err;
10531053

1054+
if (auto Err = Server.startThread())
1055+
return Err;
1056+
10541057
RPCServer = &Server;
10551058
DP("Running an RPC server on device %d\n", getDeviceId());
10561059
return Plugin::success();
@@ -1624,8 +1627,11 @@ Error GenericPluginTy::deinit() {
16241627
if (GlobalHandler)
16251628
delete GlobalHandler;
16261629

1627-
if (RPCServer)
1630+
if (RPCServer) {
1631+
if (Error Err = RPCServer->shutDown())
1632+
return Err;
16281633
delete RPCServer;
1634+
}
16291635

16301636
if (RecordReplay)
16311637
delete RecordReplay;

0 commit comments

Comments
 (0)