Skip to content

Commit 015420f

Browse files
authored
Merge pull request llvm#325 from AMD-Lightning-Internal/amd/dev/rlieberm/restore-move-rpc
[Offload] Move RPC server handling to a dedicated thread (llvm#112988)
2 parents 89c375d + 86647c1 commit 015420f

File tree

10 files changed

+288
-128
lines changed

10 files changed

+288
-128
lines changed

offload/plugins-nextgen/amdgpu/src/rtl.cpp

Lines changed: 34 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1382,9 +1382,9 @@ struct AMDGPUSignalTy {
13821382
}
13831383

13841384
/// Wait until the signal gets a zero value.
1385-
Error wait(const uint64_t ActiveTimeout = 0, RPCServerTy *RPCServer = nullptr,
1385+
Error wait(const uint64_t ActiveTimeout = 0,
13861386
GenericDeviceTy *Device = nullptr) const {
1387-
if (ActiveTimeout && !RPCServer) {
1387+
if (ActiveTimeout) {
13881388
hsa_signal_value_t Got = 1;
13891389
Got = hsa_signal_wait_scacquire(HSASignal, HSA_SIGNAL_CONDITION_EQ, 0,
13901390
ActiveTimeout, HSA_WAIT_STATE_ACTIVE);
@@ -1393,14 +1393,11 @@ struct AMDGPUSignalTy {
13931393
}
13941394

13951395
// If there is an RPC device attached to this stream we run it as a server.
1396-
uint64_t Timeout = RPCServer ? 8192 : UINT64_MAX;
1397-
auto WaitState = RPCServer ? HSA_WAIT_STATE_ACTIVE : HSA_WAIT_STATE_BLOCKED;
1396+
uint64_t Timeout = UINT64_MAX;
1397+
auto WaitState = HSA_WAIT_STATE_BLOCKED;
13981398
while (hsa_signal_wait_scacquire(HSASignal, HSA_SIGNAL_CONDITION_EQ, 0,
1399-
Timeout, WaitState) != 0) {
1400-
if (RPCServer && Device)
1401-
if (auto Err = RPCServer->runServer(*Device))
1402-
return Err;
1403-
}
1399+
Timeout, WaitState) != 0)
1400+
;
14041401
return Plugin::success();
14051402
}
14061403

@@ -1895,11 +1892,6 @@ struct AMDGPUStreamTy {
18951892
/// operation that was already finalized in a previous stream synchronize.
18961893
uint32_t SyncCycle;
18971894

1898-
/// A pointer associated with an RPC server running on the given device. If
1899-
/// RPC is not being used this will be a null pointer. Otherwise, this
1900-
/// indicates that an RPC server is expected to be run on this stream.
1901-
RPCServerTy *RPCServer;
1902-
19031895
/// Mutex to protect stream's management.
19041896
mutable std::mutex Mutex;
19051897

@@ -2136,9 +2128,6 @@ struct AMDGPUStreamTy {
21362128

21372129
hsa_queue_t *getHsaQueue() { return Queue->getHsaQueue(); }
21382130

2139-
/// Attach an RPC server to this stream.
2140-
void setRPCServer(RPCServerTy *Server) { RPCServer = Server; }
2141-
21422131
/// Push an asynchronous kernel to the stream. The kernel arguments must be
21432132
/// placed in a special allocation for kernel args and must keep alive until
21442133
/// the kernel finalizes. Once the kernel is finished, the stream will release
@@ -2194,9 +2183,30 @@ struct AMDGPUStreamTy {
21942183

21952184
// Push the kernel with the output signal and an input signal (optional)
21962185
DP("Using Queue: %p with HSA Queue: %p\n", Queue, Queue->getHsaQueue());
2197-
return Queue->pushKernelLaunch(Kernel, KernelArgs, NumThreads, NumBlocks,
2198-
GroupSize, StackSize, OutputSignal,
2199-
InputSignal);
2186+
// If we are running an RPC server we want to wake up the server thread
2187+
// whenever there is a kernel running and let it sleep otherwise.
2188+
if (Device.getRPCServer())
2189+
Device.Plugin.getRPCServer().Thread->notify();
2190+
2191+
// Push the kernel with the output signal and an input signal (optional)
2192+
if (auto Err = Queue->pushKernelLaunch(Kernel, KernelArgs, NumThreads,
2193+
NumBlocks, GroupSize, StackSize,
2194+
OutputSignal, InputSignal))
2195+
return Err;
2196+
2197+
// Register a callback to indicate when the kernel is complete.
2198+
if (Device.getRPCServer()) {
2199+
if (auto Err = Slots[Curr].schedCallback(
2200+
[](void *Data) -> llvm::Error {
2201+
GenericPluginTy &Plugin =
2202+
*reinterpret_cast<GenericPluginTy *>(Data);
2203+
Plugin.getRPCServer().Thread->finish();
2204+
return Error::success();
2205+
},
2206+
&Device.Plugin))
2207+
return Err;
2208+
}
2209+
return Plugin::success();
22002210
}
22012211

22022212
/// Push an asynchronous memory copy between pinned memory buffers.
@@ -2268,9 +2278,8 @@ struct AMDGPUStreamTy {
22682278

22692279
// Wait for kernel to finish before scheduling the asynchronous copy.
22702280
if (UseSyncCopyBack && InputSignal && InputSignal->load())
2271-
if (auto Err = InputSignal->wait(StreamBusyWaitMicroseconds, RPCServer, &Device))
2281+
if (auto Err = InputSignal->wait(StreamBusyWaitMicroseconds, &Device))
22722282
return Err;
2273-
22742283
#ifdef OMPT_SUPPORT
22752284

22762285
if (OmptInfo) {
@@ -2457,8 +2466,8 @@ struct AMDGPUStreamTy {
24572466
return Plugin::success();
24582467

24592468
// Wait until all previous operations on the stream have completed.
2460-
if (auto Err = Slots[last()].Signal->wait(StreamBusyWaitMicroseconds,
2461-
RPCServer, &Device))
2469+
if (auto Err =
2470+
Slots[last()].Signal->wait(StreamBusyWaitMicroseconds, &Device))
24622471
return Err;
24632472

24642473
// Reset the stream and perform all pending post actions.
@@ -4701,7 +4710,7 @@ AMDGPUStreamTy::AMDGPUStreamTy(AMDGPUDeviceTy &Device)
47014710
: Agent(Device.getAgent()), Queue(nullptr),
47024711
SignalManager(Device.getSignalManager()), Device(Device),
47034712
// Initialize the std::deque with some empty positions.
4704-
Slots(32), NextSlot(0), SyncCycle(0), RPCServer(nullptr),
4713+
Slots(32), NextSlot(0), SyncCycle(0),
47054714
StreamBusyWaitMicroseconds(Device.getStreamBusyWaitMicroseconds()),
47064715
UseMultipleSdmaEngines(Device.useMultipleSdmaEngines()),
47074716
UseSyncCopyBack(Device.syncCopyBack()) {}
@@ -5117,10 +5126,6 @@ Error AMDGPUKernelTy::launchImpl(GenericDeviceTy &GenericDevice,
51175126
DP("No hostrpc buffer or service thread required\n");
51185127
}
51195128

5120-
// If this kernel requires an RPC server we attach its pointer to the stream.
5121-
if (GenericDevice.getRPCServer())
5122-
Stream->setRPCServer(GenericDevice.getRPCServer());
5123-
51245129
// Only COV5 implicitargs needs to be set. COV4 implicitargs are not used.
51255130
if (ImplArgs &&
51265131
getImplicitArgsSize() == sizeof(hsa_utils::AMDGPUImplicitArgsTy)) {

offload/plugins-nextgen/common/include/PluginInterface.h

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -797,12 +797,6 @@ struct GenericDeviceTy : public DeviceAllocatorTy {
797797
/// Setup the global device memory pool, if the plugin requires one.
798798
Error setupDeviceMemoryPool(GenericPluginTy &Plugin, DeviceImageTy &Image,
799799
uint64_t PoolSize);
800-
801-
// Setup the RPC server for this device if needed. This may not run on some
802-
// plugins like the CPU targets. By default, it will not be executed so it is
803-
// up to the target to override this using the shouldSetupRPCServer function.
804-
Error setupRPCServer(GenericPluginTy &Plugin, DeviceImageTy &Image);
805-
806800
/// Synchronize the current thread with the pending operations on the
807801
/// __tgt_async_info structure.
808802
Error synchronize(__tgt_async_info *AsyncInfo);

offload/plugins-nextgen/common/include/RPC.h

Lines changed: 75 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,11 @@
1919
#include "llvm/ADT/DenseMap.h"
2020
#include "llvm/Support/Error.h"
2121

22+
#include <atomic>
23+
#include <condition_variable>
2224
#include <cstdint>
25+
#include <mutex>
26+
#include <thread>
2327

2428
namespace llvm::omp::target {
2529
namespace plugin {
@@ -37,6 +41,12 @@ struct RPCServerTy {
3741
/// Initializes the handles to the number of devices we may need to service.
3842
RPCServerTy(plugin::GenericPluginTy &Plugin);
3943

44+
/// Deinitialize the associated memory and resources.
45+
llvm::Error shutDown();
46+
47+
/// Initialize the worker thread.
48+
llvm::Error startThread();
49+
4050
/// Check if this device image is using an RPC server. This checks for the
4151
/// precense of an externally visible symbol in the device image that will
4252
/// be present whenever RPC code is called.
@@ -51,17 +61,77 @@ struct RPCServerTy {
5161
plugin::GenericGlobalHandlerTy &Handler,
5262
plugin::DeviceImageTy &Image);
5363

54-
/// Runs the RPC server associated with the \p Device until the pending work
55-
/// is cleared.
56-
llvm::Error runServer(plugin::GenericDeviceTy &Device);
57-
5864
/// Deinitialize the RPC server for the given device. This will free the
5965
/// memory associated with that device.
6066
llvm::Error deinitDevice(plugin::GenericDeviceTy &Device);
6167

6268
private:
6369
/// Array from this device's identifier to its attached devices.
64-
llvm::SmallVector<void *> Buffers;
70+
std::unique_ptr<void *[]> Buffers;
71+
72+
/// Array of associated devices. These must be alive as long as the server is.
73+
std::unique_ptr<plugin::GenericDeviceTy *[]> Devices;
74+
75+
/// A helper class for running the user thread that handles the RPC interface.
76+
/// Because we only need to check the RPC server while any kernels are
77+
/// working, we track submission / completion events to allow the thread to
78+
/// sleep when it is not needed.
79+
struct ServerThread {
80+
std::thread Worker;
81+
82+
/// A boolean indicating whether or not the worker thread should continue.
83+
std::atomic<bool> Running;
84+
85+
/// The number of currently executing kernels across all devices that need
86+
/// the server thread to be running.
87+
std::atomic<uint32_t> NumUsers;
88+
89+
/// The condition variable used to suspend the thread if no work is needed.
90+
std::condition_variable CV;
91+
std::mutex Mutex;
92+
93+
/// A reference to all the RPC interfaces that the server is handling.
94+
llvm::ArrayRef<void *> Buffers;
95+
96+
/// A reference to the associated generic device for the buffer.
97+
llvm::ArrayRef<plugin::GenericDeviceTy *> Devices;
98+
99+
/// Initialize the worker thread to run in the background.
100+
ServerThread(void *Buffers[], plugin::GenericDeviceTy *Devices[],
101+
size_t Length)
102+
: Running(false), NumUsers(0), CV(), Mutex(), Buffers(Buffers, Length),
103+
Devices(Devices, Length) {}
104+
105+
~ServerThread() { assert(!Running && "Thread not shut down explicitly\n"); }
106+
107+
/// Notify the worker thread that there is a user that needs it.
108+
void notify() {
109+
std::lock_guard<decltype(Mutex)> Lock(Mutex);
110+
NumUsers.fetch_add(1, std::memory_order_relaxed);
111+
CV.notify_all();
112+
}
113+
114+
/// Indicate that one of the dependent users has finished.
115+
void finish() {
116+
[[maybe_unused]] uint32_t Old =
117+
NumUsers.fetch_sub(1, std::memory_order_relaxed);
118+
assert(Old > 0 && "Attempt to signal finish with no pending work");
119+
}
120+
121+
/// Destroy the worker thread and wait.
122+
void shutDown();
123+
124+
/// Initialize the worker thread.
125+
void startThread();
126+
127+
/// Run the server thread to continuously check the RPC interface for work
128+
/// to be done for the device.
129+
void run();
130+
};
131+
132+
public:
133+
/// Pointer to the server thread instance.
134+
std::unique_ptr<ServerThread> Thread;
65135
};
66136

67137
} // namespace llvm::omp::target

offload/plugins-nextgen/common/src/PluginInterface.cpp

Lines changed: 4 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1132,10 +1132,6 @@ GenericDeviceTy::loadBinary(GenericPluginTy &Plugin,
11321132
} else if (auto Err = setupDeviceMemoryPool(Plugin, *Image, HeapSize))
11331133
return std::move(Err);
11341134
}
1135-
1136-
if (auto Err = setupRPCServer(Plugin, *Image))
1137-
return std::move(Err);
1138-
11391135
#ifdef OMPT_SUPPORT
11401136
if (ompt::Initialized) {
11411137
size_t Bytes =
@@ -1249,30 +1245,6 @@ Error GenericDeviceTy::setupDeviceMemoryPool(GenericPluginTy &Plugin,
12491245
return GHandler.writeGlobalToDevice(*this, Image, DevEnvGlobal);
12501246
}
12511247

1252-
Error GenericDeviceTy::setupRPCServer(GenericPluginTy &Plugin,
1253-
DeviceImageTy &Image) {
1254-
// The plugin either does not need an RPC server or it is unavailable.
1255-
if (!shouldSetupRPCServer())
1256-
return Plugin::success();
1257-
1258-
// Check if this device needs to run an RPC server.
1259-
RPCServerTy &Server = Plugin.getRPCServer();
1260-
auto UsingOrErr =
1261-
Server.isDeviceUsingRPC(*this, Plugin.getGlobalHandler(), Image);
1262-
if (!UsingOrErr)
1263-
return UsingOrErr.takeError();
1264-
1265-
if (!UsingOrErr.get())
1266-
return Plugin::success();
1267-
1268-
if (auto Err = Server.initDevice(*this, Plugin.getGlobalHandler(), Image))
1269-
return Err;
1270-
1271-
RPCServer = &Server;
1272-
DP("Running an RPC server on device %d\n", getDeviceId());
1273-
return Plugin::success();
1274-
}
1275-
12761248
Error PinnedAllocationMapTy::insertEntry(void *HstPtr, void *DevAccessiblePtr,
12771249
size_t Size, bool ExternallyLocked) {
12781250
// Insert the new entry into the map.
@@ -1892,8 +1864,11 @@ Error GenericPluginTy::deinit() {
18921864
delete GlobalHandler;
18931865

18941866
#if RPC_FIXME
1895-
if (RPCServer)
1867+
if (RPCServer) {
1868+
if (Error Err = RPCServer->shutDown())
1869+
return Err;
18961870
delete RPCServer;
1871+
}
18971872
#endif
18981873

18991874
if (RecordReplay)

0 commit comments

Comments
 (0)