
Commit d103317

[Offload] Move RPC server handling to a dedicated thread
Summary: Handling the RPC server requires running through the list of jobs that the device has requested to be done. Currently this is handled by the thread that waits for the kernel to finish. However, this is not sound on NVIDIA architectures and only works for asynchronous launches in the OpenMP model that uses helper threads. At the same time, we do not want this thread doing work unnecessarily. For this reason we track the execution of kernels and put the thread to sleep via a condition variable (usually backed by a futex or some other intelligent sleeping mechanism) so that the thread stays idle while no kernels are running.
1 parent 3b2320b commit d103317
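As a rough, self-contained sketch of the pattern the summary describes (not the plugin code itself; the names WorkerThread and pollRPCInterfaces are illustrative placeholders, while the real implementation is RPCServerTy::ServerThread in the diffs below), the worker sleeps on a condition variable and only polls while at least one kernel is in flight:

#include <atomic>
#include <condition_variable>
#include <cstdint>
#include <mutex>
#include <thread>

// Illustrative sketch only: a worker that idles on a condition variable until
// a kernel launch calls notify(), and goes back to sleep once every launch has
// called finish(). pollRPCInterfaces() stands in for the real RPC handling.
struct WorkerThread {
  std::atomic<bool> Running{true};
  std::atomic<uint32_t> NumUsers{0};
  std::condition_variable CV;
  std::mutex Mutex;
  std::thread Worker;

  WorkerThread() : Worker([this]() { run(); }) {}

  void notify() { // called right before a kernel is launched
    std::lock_guard<std::mutex> Lock(Mutex);
    NumUsers.fetch_add(1, std::memory_order_relaxed);
    CV.notify_all();
  }

  void finish() { // called from the kernel-completion callback
    NumUsers.fetch_sub(1, std::memory_order_relaxed);
  }

  void shutDown() { // called once at plugin teardown
    {
      std::lock_guard<std::mutex> Lock(Mutex);
      Running.store(false, std::memory_order_release);
      CV.notify_all();
    }
    Worker.join();
  }

  void run() {
    for (;;) {
      std::unique_lock<std::mutex> Lock(Mutex);
      // Sleep (futex-backed on most platforms) until a kernel starts running
      // or we are asked to shut down.
      CV.wait(Lock, [&]() { return NumUsers > 0 || !Running; });
      if (!Running)
        return;
      while (NumUsers > 0 && Running) {
        Lock.unlock();
        pollRPCInterfaces(); // placeholder for checking each device's RPC queue
        Lock.lock();
      }
    }
  }

  void pollRPCInterfaces() {} // no-op stand-in for rpc_handle_server()
};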

8 files changed: +226 −65 lines

offload/plugins-nextgen/amdgpu/src/rtl.cpp

Lines changed: 32 additions & 27 deletions
@@ -626,9 +626,9 @@ struct AMDGPUSignalTy {
   }
 
   /// Wait until the signal gets a zero value.
-  Error wait(const uint64_t ActiveTimeout = 0, RPCServerTy *RPCServer = nullptr,
+  Error wait(const uint64_t ActiveTimeout = 0,
              GenericDeviceTy *Device = nullptr) const {
-    if (ActiveTimeout && !RPCServer) {
+    if (ActiveTimeout) {
       hsa_signal_value_t Got = 1;
       Got = hsa_signal_wait_scacquire(HSASignal, HSA_SIGNAL_CONDITION_EQ, 0,
                                       ActiveTimeout, HSA_WAIT_STATE_ACTIVE);
@@ -637,14 +637,11 @@ struct AMDGPUSignalTy {
     }
 
     // If there is an RPC device attached to this stream we run it as a server.
-    uint64_t Timeout = RPCServer ? 8192 : UINT64_MAX;
-    auto WaitState = RPCServer ? HSA_WAIT_STATE_ACTIVE : HSA_WAIT_STATE_BLOCKED;
+    uint64_t Timeout = UINT64_MAX;
+    auto WaitState = HSA_WAIT_STATE_BLOCKED;
     while (hsa_signal_wait_scacquire(HSASignal, HSA_SIGNAL_CONDITION_EQ, 0,
-                                     Timeout, WaitState) != 0) {
-      if (RPCServer && Device)
-        if (auto Err = RPCServer->runServer(*Device))
-          return Err;
-    }
+                                     Timeout, WaitState) != 0)
+      ;
     return Plugin::success();
   }
 
@@ -1052,11 +1049,6 @@ struct AMDGPUStreamTy {
   /// operation that was already finalized in a previous stream sycnhronize.
   uint32_t SyncCycle;
 
-  /// A pointer associated with an RPC server running on the given device. If
-  /// RPC is not being used this will be a null pointer. Otherwise, this
-  /// indicates that an RPC server is expected to be run on this stream.
-  RPCServerTy *RPCServer;
-
   /// Mutex to protect stream's management.
   mutable std::mutex Mutex;
 
@@ -1236,9 +1228,6 @@ struct AMDGPUStreamTy {
   /// Deinitialize the stream's signals.
   Error deinit() { return Plugin::success(); }
 
-  /// Attach an RPC server to this stream.
-  void setRPCServer(RPCServerTy *Server) { RPCServer = Server; }
-
   /// Push a asynchronous kernel to the stream. The kernel arguments must be
   /// placed in a special allocation for kernel args and must keep alive until
   /// the kernel finalizes. Once the kernel is finished, the stream will release
@@ -1266,10 +1255,30 @@ struct AMDGPUStreamTy {
     if (auto Err = Slots[Curr].schedReleaseBuffer(KernelArgs, MemoryManager))
       return Err;
 
+    // If we are running an RPC server we want to wake up the server thread
+    // whenever there is a kernel running and let it sleep otherwise.
+    if (Device.getRPCServer())
+      Device.Plugin.getRPCServer().Thread->notify();
+
     // Push the kernel with the output signal and an input signal (optional)
-    return Queue->pushKernelLaunch(Kernel, KernelArgs, NumThreads, NumBlocks,
-                                   GroupSize, StackSize, OutputSignal,
-                                   InputSignal);
+    if (auto Err = Queue->pushKernelLaunch(Kernel, KernelArgs, NumThreads,
+                                           NumBlocks, GroupSize, StackSize,
+                                           OutputSignal, InputSignal))
+      return Err;
+
+    // Register a callback to indicate when the kernel is complete.
+    if (Device.getRPCServer()) {
+      if (auto Err = Slots[Curr].schedCallback(
+              [](void *Data) -> llvm::Error {
+                GenericPluginTy &Plugin =
+                    *reinterpret_cast<GenericPluginTy *>(Data);
+                Plugin.getRPCServer().Thread->finish();
+                return Error::success();
+              },
+              &Device.Plugin))
+        return Err;
+    }
+    return Plugin::success();
   }
 
   /// Push an asynchronous memory copy between pinned memory buffers.
@@ -1479,8 +1488,8 @@ struct AMDGPUStreamTy {
       return Plugin::success();
 
     // Wait until all previous operations on the stream have completed.
-    if (auto Err = Slots[last()].Signal->wait(StreamBusyWaitMicroseconds,
-                                              RPCServer, &Device))
+    if (auto Err =
+            Slots[last()].Signal->wait(StreamBusyWaitMicroseconds, &Device))
       return Err;
 
     // Reset the stream and perform all pending post actions.
@@ -3024,7 +3033,7 @@ AMDGPUStreamTy::AMDGPUStreamTy(AMDGPUDeviceTy &Device)
     : Agent(Device.getAgent()), Queue(nullptr),
      SignalManager(Device.getSignalManager()), Device(Device),
      // Initialize the std::deque with some empty positions.
-      Slots(32), NextSlot(0), SyncCycle(0), RPCServer(nullptr),
+      Slots(32), NextSlot(0), SyncCycle(0),
      StreamBusyWaitMicroseconds(Device.getStreamBusyWaitMicroseconds()),
      UseMultipleSdmaEngines(Device.useMultipleSdmaEngines()) {}
 
@@ -3377,10 +3386,6 @@ Error AMDGPUKernelTy::launchImpl(GenericDeviceTy &GenericDevice,
   if (auto Err = AMDGPUDevice.getStream(AsyncInfoWrapper, Stream))
     return Err;
 
-  // If this kernel requires an RPC server we attach its pointer to the stream.
-  if (GenericDevice.getRPCServer())
-    Stream->setRPCServer(GenericDevice.getRPCServer());
-
   // Only COV5 implicitargs needs to be set. COV4 implicitargs are not used.
   if (ImplArgs &&
       getImplicitArgsSize() == sizeof(hsa_utils::AMDGPUImplicitArgsTy)) {

offload/plugins-nextgen/common/include/RPC.h

Lines changed: 56 additions & 5 deletions
@@ -19,7 +19,11 @@
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/Support/Error.h"
 
+#include <atomic>
+#include <condition_variable>
 #include <cstdint>
+#include <mutex>
+#include <thread>
 
 namespace llvm::omp::target {
 namespace plugin {
@@ -37,6 +41,9 @@ struct RPCServerTy {
   /// Initializes the handles to the number of devices we may need to service.
   RPCServerTy(plugin::GenericPluginTy &Plugin);
 
+  /// Deinitialize the associated memory and resources.
+  llvm::Error shutDown();
+
   /// Check if this device image is using an RPC server. This checks for the
   /// precense of an externally visible symbol in the device image that will
   /// be present whenever RPC code is called.
@@ -51,17 +58,61 @@ struct RPCServerTy {
                            plugin::GenericGlobalHandlerTy &Handler,
                            plugin::DeviceImageTy &Image);
 
-  /// Runs the RPC server associated with the \p Device until the pending work
-  /// is cleared.
-  llvm::Error runServer(plugin::GenericDeviceTy &Device);
-
   /// Deinitialize the RPC server for the given device. This will free the
   /// memory associated with the k
   llvm::Error deinitDevice(plugin::GenericDeviceTy &Device);
 
 private:
   /// Array from this device's identifier to its attached devices.
-  llvm::SmallVector<uintptr_t> Handles;
+  std::unique_ptr<std::atomic<uintptr_t>[]> Handles;
+
+  /// A helper class for running the user thread that handles
+  struct ServerThread {
+    std::thread Worker;
+
+    /// A boolean indicating whether or not the worker thread should continue.
+    std::atomic<bool> Running;
+
+    /// The number of currently executing kernels across all devices that need
+    /// the server thread to be running.
+    std::atomic<uint32_t> NumUsers;
+
+    /// The condition variable used to suspend the thread if no work is needed.
+    std::condition_variable CV;
+    std::mutex Mutex;
+
+    /// A reference to all the RPC interfaces that the server is handling.
+    llvm::ArrayRef<std::atomic<uintptr_t>> Handles;
+
+    /// Initialize the worker thread to run in the background.
+    ServerThread(std::atomic<uintptr_t> Handles[], size_t Length);
+    ~ServerThread() { assert(!Running && "Thread not shut down explicitly\n"); }
+
+    /// Notify the worker thread that there is a user that needs it.
+    void notify() {
+      std::lock_guard<decltype(Mutex)> Lock(Mutex);
+      NumUsers.fetch_add(1, std::memory_order_relaxed);
+      CV.notify_all();
+    }
+
+    /// Indicate that one of the dependent users has finished.
+    void finish() {
+      [[maybe_unused]] uint32_t Old =
+          NumUsers.fetch_sub(1, std::memory_order_relaxed);
+      assert(Old > 0 && "Attempt to signal finish with no pending work");
+    }
+
+    /// Destroy the worker thread and wait.
+    void shutDown();
+
+    /// Run the server thread to continuously check the RPC interface for work
+    /// to be done for the device.
+    void run();
+  };
+
+public:
+  /// Pointer to the server thread instance.
+  std::unique_ptr<ServerThread> Thread;
 };
 
 } // namespace llvm::omp::target

offload/plugins-nextgen/common/src/PluginInterface.cpp

Lines changed: 4 additions & 1 deletion
@@ -1624,8 +1624,11 @@ Error GenericPluginTy::deinit() {
   if (GlobalHandler)
     delete GlobalHandler;
 
-  if (RPCServer)
+  if (RPCServer) {
+    if (Error Err = RPCServer->shutDown())
+      return Err;
     delete RPCServer;
+  }
 
   if (RecordReplay)
     delete RecordReplay;

offload/plugins-nextgen/common/src/RPC.cpp

Lines changed: 57 additions & 12 deletions
@@ -21,8 +21,64 @@ using namespace llvm;
 using namespace omp;
 using namespace target;
 
+void RPCServerTy::ServerThread::shutDown() {
+#ifdef LIBOMPTARGET_RPC_SUPPORT
+  {
+    std::lock_guard<decltype(Mutex)> Lock(Mutex);
+    Running.store(false, std::memory_order_release);
+    CV.notify_all();
+  }
+  if (Worker.joinable())
+    Worker.join();
+#endif
+}
+
+void RPCServerTy::ServerThread::run() {
+#ifdef LIBOMPTARGET_RPC_SUPPORT
+  for (;;) {
+    std::unique_lock<decltype(Mutex)> Lock(Mutex);
+    CV.wait(Lock, [&]() {
+      return NumUsers.load(std::memory_order_acquire) > 0 ||
+             !Running.load(std::memory_order_acquire);
+    });
+
+    if (!Running.load(std::memory_order_acq_rel))
+      return;
+
+    while (NumUsers.load(std::memory_order_relaxed) > 0 &&
+           Running.load(std::memory_order_relaxed)) {
+      Lock.unlock();
+      for (const auto &Handle : Handles) {
+        rpc_device_t RPCDevice{Handle};
+        [[maybe_unused]] rpc_status_t Err = rpc_handle_server(RPCDevice);
+        assert(Err == RPC_STATUS_SUCCESS &&
+               "Checking the RPC server should not fail");
+      }
+      Lock.lock();
+    }
+  }
+#endif
+}
+
+RPCServerTy::ServerThread::ServerThread(std::atomic<uintptr_t> Handles[],
+                                        size_t Length)
+    : Running(true), NumUsers(0), CV(), Mutex(), Handles(Handles, Length) {
+#ifdef LIBOMPTARGET_RPC_SUPPORT
+  Worker = std::thread([this]() { run(); });
+#endif
+}
+
 RPCServerTy::RPCServerTy(plugin::GenericPluginTy &Plugin)
-    : Handles(Plugin.getNumDevices()) {}
+    : Handles(
+          std::make_unique<std::atomic<uintptr_t>[]>(Plugin.getNumDevices())),
+      Thread(new ServerThread(Handles.get(), Plugin.getNumDevices())) {}
+
+llvm::Error RPCServerTy::shutDown() {
+#ifdef LIBOMPTARGET_RPC_SUPPORT
+  Thread->shutDown();
+#endif
+  return Error::success();
+}
 
 llvm::Expected<bool>
 RPCServerTy::isDeviceUsingRPC(plugin::GenericDeviceTy &Device,
@@ -109,17 +165,6 @@ Error RPCServerTy::initDevice(plugin::GenericDeviceTy &Device,
   return Error::success();
 }
 
-Error RPCServerTy::runServer(plugin::GenericDeviceTy &Device) {
-#ifdef LIBOMPTARGET_RPC_SUPPORT
-  rpc_device_t RPCDevice{Handles[Device.getDeviceId()]};
-  if (rpc_status_t Err = rpc_handle_server(RPCDevice))
-    return plugin::Plugin::error(
-        "Error while running RPC server on device %d: %d", Device.getDeviceId(),
-        Err);
-#endif
-  return Error::success();
-}
-
 Error RPCServerTy::deinitDevice(plugin::GenericDeviceTy &Device) {
 #ifdef LIBOMPTARGET_RPC_SUPPORT
   rpc_device_t RPCDevice{Handles[Device.getDeviceId()]};

offload/plugins-nextgen/cuda/dynamic_cuda/cuda.cpp

Lines changed: 1 addition & 0 deletions
@@ -63,6 +63,7 @@ DLWRAP(cuStreamCreate, 2)
 DLWRAP(cuStreamDestroy, 1)
 DLWRAP(cuStreamSynchronize, 1)
 DLWRAP(cuStreamQuery, 1)
+DLWRAP(cuStreamAddCallback, 4)
 DLWRAP(cuCtxSetCurrent, 1)
 DLWRAP(cuDevicePrimaryCtxRelease, 1)
 DLWRAP(cuDevicePrimaryCtxGetState, 3)

offload/plugins-nextgen/cuda/dynamic_cuda/cuda.h

Lines changed: 3 additions & 0 deletions
@@ -286,6 +286,8 @@ static inline void *CU_LAUNCH_PARAM_END = (void *)0x00;
 static inline void *CU_LAUNCH_PARAM_BUFFER_POINTER = (void *)0x01;
 static inline void *CU_LAUNCH_PARAM_BUFFER_SIZE = (void *)0x02;
 
+typedef void (*CUstreamCallback)(CUstream, CUresult, void *);
+
 CUresult cuCtxGetDevice(CUdevice *);
 CUresult cuDeviceGet(CUdevice *, int);
 CUresult cuDeviceGetAttribute(int *, CUdevice_attribute, CUdevice);
@@ -326,6 +328,7 @@ CUresult cuStreamCreate(CUstream *, unsigned);
 CUresult cuStreamDestroy(CUstream);
 CUresult cuStreamSynchronize(CUstream);
 CUresult cuStreamQuery(CUstream);
+CUresult cuStreamAddCallback(CUstream, CUstreamCallback, void *, unsigned int);
 CUresult cuCtxSetCurrent(CUcontext);
 CUresult cuDevicePrimaryCtxRelease(CUdevice);
 CUresult cuDevicePrimaryCtxGetState(CUdevice, unsigned *, int *);

offload/plugins-nextgen/cuda/src/rtl.cpp

Lines changed: 17 additions & 20 deletions
@@ -632,15 +632,7 @@ struct CUDADeviceTy : public GenericDeviceTy {
   CUresult Res;
   // If we have an RPC server running on this device we will continuously
   // query it for work rather than blocking.
-  if (!getRPCServer()) {
-    Res = cuStreamSynchronize(Stream);
-  } else {
-    do {
-      Res = cuStreamQuery(Stream);
-      if (auto Err = getRPCServer()->runServer(*this))
-        return Err;
-    } while (Res == CUDA_ERROR_NOT_READY);
-  }
+  Res = cuStreamSynchronize(Stream);
 
   // Once the stream is synchronized, return it to stream pool and reset
   // AsyncInfo. This is to make sure the synchronization only works for its
@@ -825,17 +817,6 @@ struct CUDADeviceTy : public GenericDeviceTy {
   if (auto Err = getStream(AsyncInfoWrapper, Stream))
     return Err;
 
-  // If there is already pending work on the stream it could be waiting for
-  // someone to check the RPC server.
-  if (auto *RPCServer = getRPCServer()) {
-    CUresult Res = cuStreamQuery(Stream);
-    while (Res == CUDA_ERROR_NOT_READY) {
-      if (auto Err = RPCServer->runServer(*this))
-        return Err;
-      Res = cuStreamQuery(Stream);
-    }
-  }
-
   CUresult Res = cuMemcpyDtoHAsync(HstPtr, (CUdeviceptr)TgtPtr, Size, Stream);
   return Plugin::check(Res, "Error in cuMemcpyDtoHAsync: %s");
 }
@@ -1294,10 +1275,26 @@ Error CUDAKernelTy::launchImpl(GenericDeviceTy &GenericDevice,
                     reinterpret_cast<void *>(&LaunchParams.Size),
                     CU_LAUNCH_PARAM_END};
 
+  // If we are running an RPC server we want to wake up the server thread
+  // whenever there is a kernel running and let it sleep otherwise.
+  if (GenericDevice.getRPCServer())
+    GenericDevice.Plugin.getRPCServer().Thread->notify();
+
   CUresult Res = cuLaunchKernel(Func, NumBlocks, /*gridDimY=*/1,
                                 /*gridDimZ=*/1, NumThreads,
                                 /*blockDimY=*/1, /*blockDimZ=*/1,
                                 MaxDynCGroupMem, Stream, nullptr, Config);
+
+  // Register a callback to indicate when the kernel is complete.
+  if (GenericDevice.getRPCServer())
+    cuStreamAddCallback(
+        Stream,
+        [](CUstream Stream, CUresult Status, void *Data) {
+          GenericPluginTy &Plugin = *reinterpret_cast<GenericPluginTy *>(Data);
+          Plugin.getRPCServer().Thread->finish();
+        },
+        &GenericDevice.Plugin, /*flags=*/0);
+
   return Plugin::check(Res, "Error in cuLaunchKernel for '%s': %s", getName());
 }
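Taken together, the changes above give the server thread a simple lifecycle; the comment block below paraphrases the call sequence using only the entry points introduced in these diffs:

// Plugin init   : RPCServerTy's constructor starts ServerThread, whose run()
//                 blocks on the condition variable until it has a user.
// Kernel launch : pushKernelLaunch() / launchImpl() call Thread->notify(),
//                 incrementing NumUsers and waking the worker to poll RPC.
// Kernel done   : the stream/slot completion callback calls Thread->finish(),
//                 decrementing NumUsers; at zero the worker goes back to sleep.
// Plugin deinit : GenericPluginTy::deinit() calls RPCServer->shutDown(), which
//                 clears Running, wakes the worker, and joins it.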
