Skip to content

Commit 3cb10a2

Browse files
jhuber6ronlieb
authored andcommitted
[Offload] Move RPC server handling to a dedicated thread (llvm#112988)
Summary: Handling the RPC server requires running through list of jobs that the device has requested to be done. Currently this is handled by the thread that does the waiting for the kernel to finish. However, this is not sound on NVIDIA architectures and only works for async launches in the OpenMP model that uses helper threads. However, we also don't want to have this thread doing work unnnecessarily. For this reason we track the execution of kernels and cause the thread to sleep via a condition variable (usually backed by some kind of futex or other intelligent sleeping mechanism) so that the thread will be idle while no kernels are running.
1 parent 1546397 commit 3cb10a2

File tree

10 files changed

+283
-128
lines changed

10 files changed

+283
-128
lines changed

offload/plugins-nextgen/amdgpu/src/rtl.cpp

Lines changed: 34 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1382,9 +1382,9 @@ struct AMDGPUSignalTy {
13821382
}
13831383

13841384
/// Wait until the signal gets a zero value.
1385-
Error wait(const uint64_t ActiveTimeout = 0, RPCServerTy *RPCServer = nullptr,
1385+
Error wait(const uint64_t ActiveTimeout = 0,
13861386
GenericDeviceTy *Device = nullptr) const {
1387-
if (ActiveTimeout && !RPCServer) {
1387+
if (ActiveTimeout) {
13881388
hsa_signal_value_t Got = 1;
13891389
Got = hsa_signal_wait_scacquire(HSASignal, HSA_SIGNAL_CONDITION_EQ, 0,
13901390
ActiveTimeout, HSA_WAIT_STATE_ACTIVE);
@@ -1393,14 +1393,11 @@ struct AMDGPUSignalTy {
13931393
}
13941394

13951395
// If there is an RPC device attached to this stream we run it as a server.
1396-
uint64_t Timeout = RPCServer ? 8192 : UINT64_MAX;
1397-
auto WaitState = RPCServer ? HSA_WAIT_STATE_ACTIVE : HSA_WAIT_STATE_BLOCKED;
1396+
uint64_t Timeout = UINT64_MAX;
1397+
auto WaitState = HSA_WAIT_STATE_BLOCKED;
13981398
while (hsa_signal_wait_scacquire(HSASignal, HSA_SIGNAL_CONDITION_EQ, 0,
1399-
Timeout, WaitState) != 0) {
1400-
if (RPCServer && Device)
1401-
if (auto Err = RPCServer->runServer(*Device))
1402-
return Err;
1403-
}
1399+
Timeout, WaitState) != 0)
1400+
;
14041401
return Plugin::success();
14051402
}
14061403

@@ -1895,11 +1892,6 @@ struct AMDGPUStreamTy {
18951892
/// operation that was already finalized in a previous stream sycnhronize.
18961893
uint32_t SyncCycle;
18971894

1898-
/// A pointer associated with an RPC server running on the given device. If
1899-
/// RPC is not being used this will be a null pointer. Otherwise, this
1900-
/// indicates that an RPC server is expected to be run on this stream.
1901-
RPCServerTy *RPCServer;
1902-
19031895
/// Mutex to protect stream's management.
19041896
mutable std::mutex Mutex;
19051897

@@ -2136,9 +2128,6 @@ struct AMDGPUStreamTy {
21362128

21372129
hsa_queue_t *getHsaQueue() { return Queue->getHsaQueue(); }
21382130

2139-
/// Attach an RPC server to this stream.
2140-
void setRPCServer(RPCServerTy *Server) { RPCServer = Server; }
2141-
21422131
/// Push a asynchronous kernel to the stream. The kernel arguments must be
21432132
/// placed in a special allocation for kernel args and must keep alive until
21442133
/// the kernel finalizes. Once the kernel is finished, the stream will release
@@ -2194,9 +2183,30 @@ struct AMDGPUStreamTy {
21942183

21952184
// Push the kernel with the output signal and an input signal (optional)
21962185
DP("Using Queue: %p with HSA Queue: %p\n", Queue, Queue->getHsaQueue());
2197-
return Queue->pushKernelLaunch(Kernel, KernelArgs, NumThreads, NumBlocks,
2198-
GroupSize, StackSize, OutputSignal,
2199-
InputSignal);
2186+
// If we are running an RPC server we want to wake up the server thread
2187+
// whenever there is a kernel running and let it sleep otherwise.
2188+
if (Device.getRPCServer())
2189+
Device.Plugin.getRPCServer().Thread->notify();
2190+
2191+
// Push the kernel with the output signal and an input signal (optional)
2192+
if (auto Err = Queue->pushKernelLaunch(Kernel, KernelArgs, NumThreads,
2193+
NumBlocks, GroupSize, StackSize,
2194+
OutputSignal, InputSignal))
2195+
return Err;
2196+
2197+
// Register a callback to indicate when the kernel is complete.
2198+
if (Device.getRPCServer()) {
2199+
if (auto Err = Slots[Curr].schedCallback(
2200+
[](void *Data) -> llvm::Error {
2201+
GenericPluginTy &Plugin =
2202+
*reinterpret_cast<GenericPluginTy *>(Data);
2203+
Plugin.getRPCServer().Thread->finish();
2204+
return Error::success();
2205+
},
2206+
&Device.Plugin))
2207+
return Err;
2208+
}
2209+
return Plugin::success();
22002210
}
22012211

22022212
/// Push an asynchronous memory copy between pinned memory buffers.
@@ -2268,9 +2278,8 @@ struct AMDGPUStreamTy {
22682278

22692279
// Wait for kernel to finish before scheduling the asynchronous copy.
22702280
if (UseSyncCopyBack && InputSignal && InputSignal->load())
2271-
if (auto Err = InputSignal->wait(StreamBusyWaitMicroseconds, RPCServer, &Device))
2281+
if (auto Err = InputSignal->wait(StreamBusyWaitMicroseconds, &Device))
22722282
return Err;
2273-
22742283
#ifdef OMPT_SUPPORT
22752284

22762285
if (OmptInfo) {
@@ -2457,8 +2466,8 @@ struct AMDGPUStreamTy {
24572466
return Plugin::success();
24582467

24592468
// Wait until all previous operations on the stream have completed.
2460-
if (auto Err = Slots[last()].Signal->wait(StreamBusyWaitMicroseconds,
2461-
RPCServer, &Device))
2469+
if (auto Err =
2470+
Slots[last()].Signal->wait(StreamBusyWaitMicroseconds, &Device))
24622471
return Err;
24632472

24642473
// Reset the stream and perform all pending post actions.
@@ -4701,7 +4710,7 @@ AMDGPUStreamTy::AMDGPUStreamTy(AMDGPUDeviceTy &Device)
47014710
: Agent(Device.getAgent()), Queue(nullptr),
47024711
SignalManager(Device.getSignalManager()), Device(Device),
47034712
// Initialize the std::deque with some empty positions.
4704-
Slots(32), NextSlot(0), SyncCycle(0), RPCServer(nullptr),
4713+
Slots(32), NextSlot(0), SyncCycle(0),
47054714
StreamBusyWaitMicroseconds(Device.getStreamBusyWaitMicroseconds()),
47064715
UseMultipleSdmaEngines(Device.useMultipleSdmaEngines()),
47074716
UseSyncCopyBack(Device.syncCopyBack()) {}
@@ -5117,10 +5126,6 @@ Error AMDGPUKernelTy::launchImpl(GenericDeviceTy &GenericDevice,
51175126
DP("No hostrpc buffer or service thread required\n");
51185127
}
51195128

5120-
// If this kernel requires an RPC server we attach its pointer to the stream.
5121-
if (GenericDevice.getRPCServer())
5122-
Stream->setRPCServer(GenericDevice.getRPCServer());
5123-
51245129
// Only COV5 implicitargs needs to be set. COV4 implicitargs are not used.
51255130
if (ImplArgs &&
51265131
getImplicitArgsSize() == sizeof(hsa_utils::AMDGPUImplicitArgsTy)) {

offload/plugins-nextgen/common/include/PluginInterface.h

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -797,12 +797,6 @@ struct GenericDeviceTy : public DeviceAllocatorTy {
797797
/// Setup the global device memory pool, if the plugin requires one.
798798
Error setupDeviceMemoryPool(GenericPluginTy &Plugin, DeviceImageTy &Image,
799799
uint64_t PoolSize);
800-
801-
// Setup the RPC server for this device if needed. This may not run on some
802-
// plugins like the CPU targets. By default, it will not be executed so it is
803-
// up to the target to override this using the shouldSetupRPCServer function.
804-
Error setupRPCServer(GenericPluginTy &Plugin, DeviceImageTy &Image);
805-
806800
/// Synchronize the current thread with the pending operations on the
807801
/// __tgt_async_info structure.
808802
Error synchronize(__tgt_async_info *AsyncInfo);

offload/plugins-nextgen/common/include/RPC.h

Lines changed: 75 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,11 @@
1919
#include "llvm/ADT/DenseMap.h"
2020
#include "llvm/Support/Error.h"
2121

22+
#include <atomic>
23+
#include <condition_variable>
2224
#include <cstdint>
25+
#include <mutex>
26+
#include <thread>
2327

2428
namespace llvm::omp::target {
2529
namespace plugin {
@@ -37,6 +41,12 @@ struct RPCServerTy {
3741
/// Initializes the handles to the number of devices we may need to service.
3842
RPCServerTy(plugin::GenericPluginTy &Plugin);
3943

44+
/// Deinitialize the associated memory and resources.
45+
llvm::Error shutDown();
46+
47+
/// Initialize the worker thread.
48+
llvm::Error startThread();
49+
4050
/// Check if this device image is using an RPC server. This checks for the
4151
/// precense of an externally visible symbol in the device image that will
4252
/// be present whenever RPC code is called.
@@ -51,17 +61,77 @@ struct RPCServerTy {
5161
plugin::GenericGlobalHandlerTy &Handler,
5262
plugin::DeviceImageTy &Image);
5363

54-
/// Runs the RPC server associated with the \p Device until the pending work
55-
/// is cleared.
56-
llvm::Error runServer(plugin::GenericDeviceTy &Device);
57-
5864
/// Deinitialize the RPC server for the given device. This will free the
5965
/// memory associated with the k
6066
llvm::Error deinitDevice(plugin::GenericDeviceTy &Device);
6167

6268
private:
6369
/// Array from this device's identifier to its attached devices.
64-
llvm::SmallVector<void *> Buffers;
70+
std::unique_ptr<void *[]> Buffers;
71+
72+
/// Array of associated devices. These must be alive as long as the server is.
73+
std::unique_ptr<plugin::GenericDeviceTy *[]> Devices;
74+
75+
/// A helper class for running the user thread that handles the RPC interface.
76+
/// Because we only need to check the RPC server while any kernels are
77+
/// working, we track submission / completion events to allow the thread to
78+
/// sleep when it is not needed.
79+
struct ServerThread {
80+
std::thread Worker;
81+
82+
/// A boolean indicating whether or not the worker thread should continue.
83+
std::atomic<bool> Running;
84+
85+
/// The number of currently executing kernels across all devices that need
86+
/// the server thread to be running.
87+
std::atomic<uint32_t> NumUsers;
88+
89+
/// The condition variable used to suspend the thread if no work is needed.
90+
std::condition_variable CV;
91+
std::mutex Mutex;
92+
93+
/// A reference to all the RPC interfaces that the server is handling.
94+
llvm::ArrayRef<void *> Buffers;
95+
96+
/// A reference to the associated generic device for the buffer.
97+
llvm::ArrayRef<plugin::GenericDeviceTy *> Devices;
98+
99+
/// Initialize the worker thread to run in the background.
100+
ServerThread(void *Buffers[], plugin::GenericDeviceTy *Devices[],
101+
size_t Length)
102+
: Running(true), NumUsers(0), CV(), Mutex(), Buffers(Buffers, Length),
103+
Devices(Devices, Length) {}
104+
105+
~ServerThread() { assert(!Running && "Thread not shut down explicitly\n"); }
106+
107+
/// Notify the worker thread that there is a user that needs it.
108+
void notify() {
109+
std::lock_guard<decltype(Mutex)> Lock(Mutex);
110+
NumUsers.fetch_add(1, std::memory_order_relaxed);
111+
CV.notify_all();
112+
}
113+
114+
/// Indicate that one of the dependent users has finished.
115+
void finish() {
116+
[[maybe_unused]] uint32_t Old =
117+
NumUsers.fetch_sub(1, std::memory_order_relaxed);
118+
assert(Old > 0 && "Attempt to signal finish with no pending work");
119+
}
120+
121+
/// Destroy the worker thread and wait.
122+
void shutDown();
123+
124+
/// Initialize the worker thread.
125+
void startThread();
126+
127+
/// Run the server thread to continuously check the RPC interface for work
128+
/// to be done for the device.
129+
void run();
130+
};
131+
132+
public:
133+
/// Pointer to the server thread instance.
134+
std::unique_ptr<ServerThread> Thread;
65135
};
66136

67137
} // namespace llvm::omp::target

offload/plugins-nextgen/common/src/PluginInterface.cpp

Lines changed: 4 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1132,10 +1132,6 @@ GenericDeviceTy::loadBinary(GenericPluginTy &Plugin,
11321132
} else if (auto Err = setupDeviceMemoryPool(Plugin, *Image, HeapSize))
11331133
return std::move(Err);
11341134
}
1135-
1136-
if (auto Err = setupRPCServer(Plugin, *Image))
1137-
return std::move(Err);
1138-
11391135
#ifdef OMPT_SUPPORT
11401136
if (ompt::Initialized) {
11411137
size_t Bytes =
@@ -1249,30 +1245,6 @@ Error GenericDeviceTy::setupDeviceMemoryPool(GenericPluginTy &Plugin,
12491245
return GHandler.writeGlobalToDevice(*this, Image, DevEnvGlobal);
12501246
}
12511247

1252-
Error GenericDeviceTy::setupRPCServer(GenericPluginTy &Plugin,
1253-
DeviceImageTy &Image) {
1254-
// The plugin either does not need an RPC server or it is unavailible.
1255-
if (!shouldSetupRPCServer())
1256-
return Plugin::success();
1257-
1258-
// Check if this device needs to run an RPC server.
1259-
RPCServerTy &Server = Plugin.getRPCServer();
1260-
auto UsingOrErr =
1261-
Server.isDeviceUsingRPC(*this, Plugin.getGlobalHandler(), Image);
1262-
if (!UsingOrErr)
1263-
return UsingOrErr.takeError();
1264-
1265-
if (!UsingOrErr.get())
1266-
return Plugin::success();
1267-
1268-
if (auto Err = Server.initDevice(*this, Plugin.getGlobalHandler(), Image))
1269-
return Err;
1270-
1271-
RPCServer = &Server;
1272-
DP("Running an RPC server on device %d\n", getDeviceId());
1273-
return Plugin::success();
1274-
}
1275-
12761248
Error PinnedAllocationMapTy::insertEntry(void *HstPtr, void *DevAccessiblePtr,
12771249
size_t Size, bool ExternallyLocked) {
12781250
// Insert the new entry into the map.
@@ -1892,8 +1864,11 @@ Error GenericPluginTy::deinit() {
18921864
delete GlobalHandler;
18931865

18941866
#if RPC_FIXME
1895-
if (RPCServer)
1867+
if (RPCServer) {
1868+
if (Error Err = RPCServer->shutDown())
1869+
return Err;
18961870
delete RPCServer;
1871+
}
18971872
#endif
18981873

18991874
if (RecordReplay)

0 commit comments

Comments
 (0)