
Commit 6bad32a

[Offload] Move RPC server handling to a dedicated thread
Summary: Handling the RPC server requires running through the list of jobs that the device has requested to be done. Currently this is handled by the thread that waits for the kernel to finish. However, this is not sound on NVIDIA architectures and only works for async launches in the OpenMP model that uses helper threads. At the same time, we do not want this new thread doing work unnecessarily. For this reason we track the execution of kernels and put the thread to sleep via a condition variable (usually backed by some kind of futex or other intelligent sleeping mechanism) so that the thread stays idle while no kernels are running.

Also: use cuLaunchHostFunc, and only create the thread if it is used.
1 parent d661aea commit 6bad32a
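
The sleep/wake scheme described above reduces to a user counter guarded by a condition variable. A minimal, self-contained sketch of the idea (simplified; RunRPCServerOnce is a placeholder, not the plugin's real API):

#include <atomic>
#include <condition_variable>
#include <cstdint>
#include <mutex>
#include <thread>

// Placeholder for servicing any pending RPC requests once.
inline void RunRPCServerOnce() {}

struct ServerThreadSketch {
  std::thread Worker;
  std::atomic<bool> Running{true};
  std::atomic<uint32_t> NumUsers{0}; // kernels currently in flight
  std::condition_variable CV;
  std::mutex Mutex;

  void start() { Worker = std::thread([this] { run(); }); }

  // Kernel launched: wake the worker.
  void notify() {
    std::lock_guard<std::mutex> Lock(Mutex);
    NumUsers.fetch_add(1);
    CV.notify_all();
  }

  // Kernel finished: allow the worker to sleep again.
  void finish() { NumUsers.fetch_sub(1); }

  void shutDown() {
    {
      std::lock_guard<std::mutex> Lock(Mutex);
      Running = false;
      CV.notify_all();
    }
    if (Worker.joinable())
      Worker.join();
  }

  void run() {
    std::unique_lock<std::mutex> Lock(Mutex);
    while (true) {
      // Sleep (typically futex-backed) until a kernel starts or we shut down.
      CV.wait(Lock, [&] { return NumUsers > 0 || !Running; });
      if (!Running)
        return;
      Lock.unlock();
      // Poll the RPC interface only while kernels are running.
      while (NumUsers > 0 && Running)
        RunRPCServerOnce();
      Lock.lock();
    }
  }
};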

File tree

8 files changed: +245, -67 lines


offload/plugins-nextgen/amdgpu/src/rtl.cpp

Lines changed: 32 additions & 27 deletions
@@ -626,9 +626,9 @@ struct AMDGPUSignalTy {
   }
 
   /// Wait until the signal gets a zero value.
-  Error wait(const uint64_t ActiveTimeout = 0, RPCServerTy *RPCServer = nullptr,
+  Error wait(const uint64_t ActiveTimeout = 0,
              GenericDeviceTy *Device = nullptr) const {
-    if (ActiveTimeout && !RPCServer) {
+    if (ActiveTimeout) {
       hsa_signal_value_t Got = 1;
       Got = hsa_signal_wait_scacquire(HSASignal, HSA_SIGNAL_CONDITION_EQ, 0,
                                       ActiveTimeout, HSA_WAIT_STATE_ACTIVE);
@@ -637,14 +637,11 @@ struct AMDGPUSignalTy {
     }
 
     // If there is an RPC device attached to this stream we run it as a server.
-    uint64_t Timeout = RPCServer ? 8192 : UINT64_MAX;
-    auto WaitState = RPCServer ? HSA_WAIT_STATE_ACTIVE : HSA_WAIT_STATE_BLOCKED;
+    uint64_t Timeout = UINT64_MAX;
+    auto WaitState = HSA_WAIT_STATE_BLOCKED;
     while (hsa_signal_wait_scacquire(HSASignal, HSA_SIGNAL_CONDITION_EQ, 0,
-                                     Timeout, WaitState) != 0) {
-      if (RPCServer && Device)
-        if (auto Err = RPCServer->runServer(*Device))
-          return Err;
-    }
+                                     Timeout, WaitState) != 0)
+      ;
     return Plugin::success();
   }
 
@@ -1053,11 +1050,6 @@ struct AMDGPUStreamTy {
   /// operation that was already finalized in a previous stream sycnhronize.
   uint32_t SyncCycle;
 
-  /// A pointer associated with an RPC server running on the given device. If
-  /// RPC is not being used this will be a null pointer. Otherwise, this
-  /// indicates that an RPC server is expected to be run on this stream.
-  RPCServerTy *RPCServer;
-
   /// Mutex to protect stream's management.
   mutable std::mutex Mutex;
 
@@ -1237,9 +1229,6 @@ struct AMDGPUStreamTy {
   /// Deinitialize the stream's signals.
   Error deinit() { return Plugin::success(); }
 
-  /// Attach an RPC server to this stream.
-  void setRPCServer(RPCServerTy *Server) { RPCServer = Server; }
-
   /// Push a asynchronous kernel to the stream. The kernel arguments must be
   /// placed in a special allocation for kernel args and must keep alive until
   /// the kernel finalizes. Once the kernel is finished, the stream will release
@@ -1267,10 +1256,30 @@ struct AMDGPUStreamTy {
     if (auto Err = Slots[Curr].schedReleaseBuffer(KernelArgs, MemoryManager))
       return Err;
 
+    // If we are running an RPC server we want to wake up the server thread
+    // whenever there is a kernel running and let it sleep otherwise.
+    if (Device.getRPCServer())
+      Device.Plugin.getRPCServer().Thread->notify();
+
     // Push the kernel with the output signal and an input signal (optional)
-    return Queue->pushKernelLaunch(Kernel, KernelArgs, NumThreads, NumBlocks,
-                                   GroupSize, StackSize, OutputSignal,
-                                   InputSignal);
+    if (auto Err = Queue->pushKernelLaunch(Kernel, KernelArgs, NumThreads,
+                                           NumBlocks, GroupSize, StackSize,
+                                           OutputSignal, InputSignal))
+      return Err;
+
+    // Register a callback to indicate when the kernel is complete.
+    if (Device.getRPCServer()) {
+      if (auto Err = Slots[Curr].schedCallback(
+              [](void *Data) -> llvm::Error {
+                GenericPluginTy &Plugin =
+                    *reinterpret_cast<GenericPluginTy *>(Data);
+                Plugin.getRPCServer().Thread->finish();
+                return Error::success();
+              },
+              &Device.Plugin))
+        return Err;
+    }
+    return Plugin::success();
   }
 
   /// Push an asynchronous memory copy between pinned memory buffers.
@@ -1480,8 +1489,8 @@ struct AMDGPUStreamTy {
       return Plugin::success();
 
     // Wait until all previous operations on the stream have completed.
-    if (auto Err = Slots[last()].Signal->wait(StreamBusyWaitMicroseconds,
-                                              RPCServer, &Device))
+    if (auto Err =
+            Slots[last()].Signal->wait(StreamBusyWaitMicroseconds, &Device))
       return Err;
 
     // Reset the stream and perform all pending post actions.
@@ -3025,7 +3034,7 @@ AMDGPUStreamTy::AMDGPUStreamTy(AMDGPUDeviceTy &Device)
     : Agent(Device.getAgent()), Queue(nullptr),
      SignalManager(Device.getSignalManager()), Device(Device),
      // Initialize the std::deque with some empty positions.
-      Slots(32), NextSlot(0), SyncCycle(0), RPCServer(nullptr),
+      Slots(32), NextSlot(0), SyncCycle(0),
      StreamBusyWaitMicroseconds(Device.getStreamBusyWaitMicroseconds()),
      UseMultipleSdmaEngines(Device.useMultipleSdmaEngines()) {}
 
@@ -3378,10 +3387,6 @@ Error AMDGPUKernelTy::launchImpl(GenericDeviceTy &GenericDevice,
   if (auto Err = AMDGPUDevice.getStream(AsyncInfoWrapper, Stream))
     return Err;
 
-  // If this kernel requires an RPC server we attach its pointer to the stream.
-  if (GenericDevice.getRPCServer())
-    Stream->setRPCServer(GenericDevice.getRPCServer());
-
   // Only COV5 implicitargs needs to be set. COV4 implicitargs are not used.
   if (ImplArgs &&
       getImplicitArgsSize() == sizeof(hsa_utils::AMDGPUImplicitArgsTy)) {

offload/plugins-nextgen/common/include/RPC.h

Lines changed: 67 additions & 5 deletions
@@ -19,7 +19,11 @@
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/Support/Error.h"
 
+#include <atomic>
+#include <condition_variable>
 #include <cstdint>
+#include <mutex>
+#include <thread>
 
 namespace llvm::omp::target {
 namespace plugin {
@@ -37,6 +41,12 @@ struct RPCServerTy {
   /// Initializes the handles to the number of devices we may need to service.
   RPCServerTy(plugin::GenericPluginTy &Plugin);
 
+  /// Deinitialize the associated memory and resources.
+  llvm::Error shutDown();
+
+  /// Initialize the worker thread.
+  llvm::Error startThread();
+
   /// Check if this device image is using an RPC server. This checks for the
   /// precense of an externally visible symbol in the device image that will
   /// be present whenever RPC code is called.
@@ -51,17 +61,69 @@ struct RPCServerTy {
                          plugin::GenericGlobalHandlerTy &Handler,
                          plugin::DeviceImageTy &Image);
 
-  /// Runs the RPC server associated with the \p Device until the pending work
-  /// is cleared.
-  llvm::Error runServer(plugin::GenericDeviceTy &Device);
-
   /// Deinitialize the RPC server for the given device. This will free the
   /// memory associated with the k
   llvm::Error deinitDevice(plugin::GenericDeviceTy &Device);
 
 private:
   /// Array from this device's identifier to its attached devices.
-  llvm::SmallVector<uintptr_t> Handles;
+  std::unique_ptr<std::atomic<uintptr_t>[]> Handles;
+
+  /// A helper class for running the user thread that handles the RPC interface.
+  /// Because we only need to check the RPC server while any kernels are
+  /// working, we track submission / completion events to allow the thread to
+  /// sleep when it is not needed.
+  struct ServerThread {
+    std::thread Worker;
+
+    /// A boolean indicating whether or not the worker thread should continue.
+    std::atomic<bool> Running;
+
+    /// The number of currently executing kernels across all devices that need
+    /// the server thread to be running.
+    std::atomic<uint32_t> NumUsers;
+
+    /// The condition variable used to suspend the thread if no work is needed.
+    std::condition_variable CV;
+    std::mutex Mutex;
+
+    /// A reference to all the RPC interfaces that the server is handling.
+    llvm::ArrayRef<std::atomic<uintptr_t>> Handles;
+
+    /// Initialize the worker thread to run in the background.
+    ServerThread(std::atomic<uintptr_t> Handles[], size_t Length)
+        : Running(true), NumUsers(0), CV(), Mutex(), Handles(Handles, Length) {}
+
+    ~ServerThread() { assert(!Running && "Thread not shut down explicitly\n"); }
+
+    /// Notify the worker thread that there is a user that needs it.
+    void notify() {
+      std::lock_guard<decltype(Mutex)> Lock(Mutex);
+      NumUsers.fetch_add(1, std::memory_order_relaxed);
+      CV.notify_all();
+    }
+
+    /// Indicate that one of the dependent users has finished.
+    void finish() {
+      [[maybe_unused]] uint32_t Old =
+          NumUsers.fetch_sub(1, std::memory_order_relaxed);
+      assert(Old > 0 && "Attempt to signal finish with no pending work");
+    }
+
+    /// Destroy the worker thread and wait.
+    void shutDown();
+
+    /// Initialize the worker thread.
+    void startThread();
+
+    /// Run the server thread to continuously check the RPC interface for work
+    /// to be done for the device.
+    void run();
+  };
+
+public:
+  /// Pointer to the server thread instance.
+  std::unique_ptr<ServerThread> Thread;
 };
 
 } // namespace llvm::omp::target
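
Taken together with the AMDGPU changes above, the intended use of this interface is roughly the following (a condensed sketch reusing the names from those hunks, not verbatim plugin code):

// On kernel submission: bump the user count so the server thread wakes up
// and starts polling the RPC interface.
if (Device.getRPCServer())
  Device.Plugin.getRPCServer().Thread->notify();

// ... enqueue the kernel on the stream ...

// On kernel completion (from a stream post-action or host callback): drop the
// user count so the thread can go back to sleep once nothing is in flight.
if (Device.getRPCServer())
  Device.Plugin.getRPCServer().Thread->finish();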

offload/plugins-nextgen/common/src/PluginInterface.cpp

Lines changed: 7 additions & 1 deletion
@@ -1051,6 +1051,9 @@ Error GenericDeviceTy::setupRPCServer(GenericPluginTy &Plugin,
   if (auto Err = Server.initDevice(*this, Plugin.getGlobalHandler(), Image))
     return Err;
 
+  if (auto Err = Server.startThread())
+    return Err;
+
   RPCServer = &Server;
   DP("Running an RPC server on device %d\n", getDeviceId());
   return Plugin::success();
@@ -1624,8 +1627,11 @@ Error GenericPluginTy::deinit() {
   if (GlobalHandler)
     delete GlobalHandler;
 
-  if (RPCServer)
+  if (RPCServer) {
+    if (Error Err = RPCServer->shutDown())
+      return Err;
     delete RPCServer;
+  }
 
   if (RecordReplay)
     delete RecordReplay;
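
The ordering here matters because ~ServerThread() asserts that the thread was stopped explicitly. A rough sketch of the lifecycle implied by these two hunks (simplified, using the same names):

// Device setup: only start the worker once a device is known to use RPC.
if (auto Err = Server.startThread())
  return Err;

// Plugin teardown: join the worker before the RPCServerTy (and with it the
// ServerThread) is destroyed.
if (Error Err = RPCServer->shutDown())
  return Err;
delete RPCServer;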

offload/plugins-nextgen/common/src/RPC.cpp

Lines changed: 62 additions & 12 deletions
@@ -21,8 +21,69 @@ using namespace llvm;
 using namespace omp;
 using namespace target;
 
+void RPCServerTy::ServerThread::startThread() {
+#ifdef LIBOMPTARGET_RPC_SUPPORT
+  Worker = std::thread([this]() { run(); });
+#endif
+}
+
+void RPCServerTy::ServerThread::shutDown() {
+#ifdef LIBOMPTARGET_RPC_SUPPORT
+  {
+    std::lock_guard<decltype(Mutex)> Lock(Mutex);
+    Running.store(false, std::memory_order_release);
+    CV.notify_all();
+  }
+  if (Worker.joinable())
+    Worker.join();
+#endif
+}
+
+void RPCServerTy::ServerThread::run() {
+#ifdef LIBOMPTARGET_RPC_SUPPORT
+  std::unique_lock<decltype(Mutex)> Lock(Mutex);
+  for (;;) {
+    CV.wait(Lock, [&]() {
+      return NumUsers.load(std::memory_order_acquire) > 0 ||
+             !Running.load(std::memory_order_acquire);
+    });
+
+    if (!Running.load(std::memory_order_acquire))
+      return;
+
+    Lock.unlock();
+    while (NumUsers.load(std::memory_order_relaxed) > 0 &&
+           Running.load(std::memory_order_relaxed)) {
+      for (const auto &Handle : Handles) {
+        rpc_device_t RPCDevice{Handle};
+        [[maybe_unused]] rpc_status_t Err = rpc_handle_server(RPCDevice);
+        assert(Err == RPC_STATUS_SUCCESS &&
+               "Checking the RPC server should not fail");
+      }
+    }
+    Lock.lock();
+  }
+#endif
+}
+
 RPCServerTy::RPCServerTy(plugin::GenericPluginTy &Plugin)
-    : Handles(Plugin.getNumDevices()) {}
+    : Handles(
+          std::make_unique<std::atomic<uintptr_t>[]>(Plugin.getNumDevices())),
+      Thread(new ServerThread(Handles.get(), Plugin.getNumDevices())) {}
+
+llvm::Error RPCServerTy::startThread() {
+#ifdef LIBOMPTARGET_RPC_SUPPORT
+  Thread->startThread();
+#endif
+  return Error::success();
+}
+
+llvm::Error RPCServerTy::shutDown() {
+#ifdef LIBOMPTARGET_RPC_SUPPORT
+  Thread->shutDown();
+#endif
+  return Error::success();
+}
 
 llvm::Expected<bool>
 RPCServerTy::isDeviceUsingRPC(plugin::GenericDeviceTy &Device,
@@ -109,17 +170,6 @@ Error RPCServerTy::initDevice(plugin::GenericDeviceTy &Device,
   return Error::success();
 }
 
-Error RPCServerTy::runServer(plugin::GenericDeviceTy &Device) {
-#ifdef LIBOMPTARGET_RPC_SUPPORT
-  rpc_device_t RPCDevice{Handles[Device.getDeviceId()]};
-  if (rpc_status_t Err = rpc_handle_server(RPCDevice))
-    return plugin::Plugin::error(
-        "Error while running RPC server on device %d: %d", Device.getDeviceId(),
-        Err);
-#endif
-  return Error::success();
-}
-
 Error RPCServerTy::deinitDevice(plugin::GenericDeviceTy &Device) {
 #ifdef LIBOMPTARGET_RPC_SUPPORT
   rpc_device_t RPCDevice{Handles[Device.getDeviceId()]};

offload/plugins-nextgen/cuda/dynamic_cuda/cuda.cpp

Lines changed: 1 addition & 0 deletions
@@ -63,6 +63,7 @@ DLWRAP(cuStreamCreate, 2)
 DLWRAP(cuStreamDestroy, 1)
 DLWRAP(cuStreamSynchronize, 1)
 DLWRAP(cuStreamQuery, 1)
+DLWRAP(cuStreamAddCallback, 4)
 DLWRAP(cuCtxSetCurrent, 1)
 DLWRAP(cuDevicePrimaryCtxRelease, 1)
 DLWRAP(cuDevicePrimaryCtxGetState, 3)

offload/plugins-nextgen/cuda/dynamic_cuda/cuda.h

Lines changed: 3 additions & 0 deletions
@@ -286,6 +286,8 @@ static inline void *CU_LAUNCH_PARAM_END = (void *)0x00;
 static inline void *CU_LAUNCH_PARAM_BUFFER_POINTER = (void *)0x01;
 static inline void *CU_LAUNCH_PARAM_BUFFER_SIZE = (void *)0x02;
 
+typedef void (*CUstreamCallback)(CUstream, CUresult, void *);
+
 CUresult cuCtxGetDevice(CUdevice *);
 CUresult cuDeviceGet(CUdevice *, int);
 CUresult cuDeviceGetAttribute(int *, CUdevice_attribute, CUdevice);
@@ -326,6 +328,7 @@ CUresult cuStreamCreate(CUstream *, unsigned);
 CUresult cuStreamDestroy(CUstream);
 CUresult cuStreamSynchronize(CUstream);
 CUresult cuStreamQuery(CUstream);
+CUresult cuStreamAddCallback(CUstream, CUstreamCallback, void *, unsigned int);
 CUresult cuCtxSetCurrent(CUcontext);
 CUresult cuDevicePrimaryCtxRelease(CUdevice);
 CUresult cuDevicePrimaryCtxGetState(CUdevice, unsigned *, int *);
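
These declarations only route cuStreamAddCallback through the dynamic loader; the CUDA plugin code that actually uses it (the commit summary also mentions cuLaunchHostFunc) is not part of this excerpt. As a hedged sketch of how a stream callback could hand completion back to the server thread, with illustrative names for Stream and Plugin:

// Hypothetical completion hook: runs on the host after all prior work on the
// stream has finished and lets the RPC server thread go back to sleep.
static void rpcDoneCallback(CUstream, CUresult, void *Data) {
  auto &Plugin = *reinterpret_cast<GenericPluginTy *>(Data);
  Plugin.getRPCServer().Thread->finish();
}

// After pushing a kernel that may use RPC; the flags argument must be zero.
if (cuStreamAddCallback(Stream, rpcDoneCallback, &Plugin, /*flags=*/0) != CUDA_SUCCESS)
  /* report the failure through the plugin's usual error path */;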
