Skip to content

[Offload] Move RPC server handling to a dedicated thread #112988

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 32 additions & 27 deletions offload/plugins-nextgen/amdgpu/src/rtl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -621,9 +621,9 @@ struct AMDGPUSignalTy {
}

/// Wait until the signal gets a zero value.
Error wait(const uint64_t ActiveTimeout = 0, RPCServerTy *RPCServer = nullptr,
Error wait(const uint64_t ActiveTimeout = 0,
GenericDeviceTy *Device = nullptr) const {
if (ActiveTimeout && !RPCServer) {
if (ActiveTimeout) {
hsa_signal_value_t Got = 1;
Got = hsa_signal_wait_scacquire(HSASignal, HSA_SIGNAL_CONDITION_EQ, 0,
ActiveTimeout, HSA_WAIT_STATE_ACTIVE);
Expand All @@ -632,14 +632,11 @@ struct AMDGPUSignalTy {
}

// If there is an RPC device attached to this stream we run it as a server.
uint64_t Timeout = RPCServer ? 8192 : UINT64_MAX;
auto WaitState = RPCServer ? HSA_WAIT_STATE_ACTIVE : HSA_WAIT_STATE_BLOCKED;
uint64_t Timeout = UINT64_MAX;
auto WaitState = HSA_WAIT_STATE_BLOCKED;
while (hsa_signal_wait_scacquire(HSASignal, HSA_SIGNAL_CONDITION_EQ, 0,
Timeout, WaitState) != 0) {
if (RPCServer && Device)
if (auto Err = RPCServer->runServer(*Device))
return Err;
}
Timeout, WaitState) != 0)
;
return Plugin::success();
}

Expand Down Expand Up @@ -1052,11 +1049,6 @@ struct AMDGPUStreamTy {
/// operation that was already finalized in a previous stream sycnhronize.
uint32_t SyncCycle;

/// A pointer associated with an RPC server running on the given device. If
/// RPC is not being used this will be a null pointer. Otherwise, this
/// indicates that an RPC server is expected to be run on this stream.
RPCServerTy *RPCServer;

/// Mutex to protect stream's management.
mutable std::mutex Mutex;

Expand Down Expand Up @@ -1236,9 +1228,6 @@ struct AMDGPUStreamTy {
/// Deinitialize the stream's signals.
Error deinit() { return Plugin::success(); }

/// Attach an RPC server to this stream.
void setRPCServer(RPCServerTy *Server) { RPCServer = Server; }

/// Push a asynchronous kernel to the stream. The kernel arguments must be
/// placed in a special allocation for kernel args and must keep alive until
/// the kernel finalizes. Once the kernel is finished, the stream will release
Expand Down Expand Up @@ -1266,10 +1255,30 @@ struct AMDGPUStreamTy {
if (auto Err = Slots[Curr].schedReleaseBuffer(KernelArgs, MemoryManager))
return Err;

// If we are running an RPC server we want to wake up the server thread
// whenever there is a kernel running and let it sleep otherwise.
if (Device.getRPCServer())
Device.Plugin.getRPCServer().Thread->notify();

// Push the kernel with the output signal and an input signal (optional)
return Queue->pushKernelLaunch(Kernel, KernelArgs, NumThreads, NumBlocks,
GroupSize, StackSize, OutputSignal,
InputSignal);
if (auto Err = Queue->pushKernelLaunch(Kernel, KernelArgs, NumThreads,
NumBlocks, GroupSize, StackSize,
OutputSignal, InputSignal))
return Err;

// Register a callback to indicate when the kernel is complete.
if (Device.getRPCServer()) {
if (auto Err = Slots[Curr].schedCallback(
[](void *Data) -> llvm::Error {
GenericPluginTy &Plugin =
*reinterpret_cast<GenericPluginTy *>(Data);
Plugin.getRPCServer().Thread->finish();
return Error::success();
},
&Device.Plugin))
return Err;
}
return Plugin::success();
}

/// Push an asynchronous memory copy between pinned memory buffers.
Expand Down Expand Up @@ -1479,8 +1488,8 @@ struct AMDGPUStreamTy {
return Plugin::success();

// Wait until all previous operations on the stream have completed.
if (auto Err = Slots[last()].Signal->wait(StreamBusyWaitMicroseconds,
RPCServer, &Device))
if (auto Err =
Slots[last()].Signal->wait(StreamBusyWaitMicroseconds, &Device))
return Err;

// Reset the stream and perform all pending post actions.
Expand Down Expand Up @@ -3027,7 +3036,7 @@ AMDGPUStreamTy::AMDGPUStreamTy(AMDGPUDeviceTy &Device)
: Agent(Device.getAgent()), Queue(nullptr),
SignalManager(Device.getSignalManager()), Device(Device),
// Initialize the std::deque with some empty positions.
Slots(32), NextSlot(0), SyncCycle(0), RPCServer(nullptr),
Slots(32), NextSlot(0), SyncCycle(0),
StreamBusyWaitMicroseconds(Device.getStreamBusyWaitMicroseconds()),
UseMultipleSdmaEngines(Device.useMultipleSdmaEngines()) {}

Expand Down Expand Up @@ -3383,10 +3392,6 @@ Error AMDGPUKernelTy::launchImpl(GenericDeviceTy &GenericDevice,
if (auto Err = AMDGPUDevice.getStream(AsyncInfoWrapper, Stream))
return Err;

// If this kernel requires an RPC server we attach its pointer to the stream.
if (GenericDevice.getRPCServer())
Stream->setRPCServer(GenericDevice.getRPCServer());

// Only COV5 implicitargs needs to be set. COV4 implicitargs are not used.
if (ImplArgs &&
getImplicitArgsSize() == sizeof(hsa_utils::AMDGPUImplicitArgsTy)) {
Expand Down
80 changes: 75 additions & 5 deletions offload/plugins-nextgen/common/include/RPC.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/Error.h"

#include <atomic>
#include <condition_variable>
#include <cstdint>
#include <mutex>
#include <thread>

namespace llvm::omp::target {
namespace plugin {
Expand All @@ -37,6 +41,12 @@ struct RPCServerTy {
/// Initializes the handles to the number of devices we may need to service.
RPCServerTy(plugin::GenericPluginTy &Plugin);

/// Deinitialize the associated memory and resources.
llvm::Error shutDown();

/// Initialize the worker thread.
llvm::Error startThread();

/// Check if this device image is using an RPC server. This checks for the
/// precense of an externally visible symbol in the device image that will
/// be present whenever RPC code is called.
Expand All @@ -51,17 +61,77 @@ struct RPCServerTy {
plugin::GenericGlobalHandlerTy &Handler,
plugin::DeviceImageTy &Image);

/// Runs the RPC server associated with the \p Device until the pending work
/// is cleared.
llvm::Error runServer(plugin::GenericDeviceTy &Device);

/// Deinitialize the RPC server for the given device. This will free the
/// memory associated with the k
llvm::Error deinitDevice(plugin::GenericDeviceTy &Device);

private:
/// Array from this device's identifier to its attached devices.
llvm::SmallVector<void *> Buffers;
std::unique_ptr<void *[]> Buffers;

/// Array of associated devices. These must be alive as long as the server is.
std::unique_ptr<plugin::GenericDeviceTy *[]> Devices;

/// A helper class for running the user thread that handles the RPC interface.
/// Because we only need to check the RPC server while any kernels are
/// working, we track submission / completion events to allow the thread to
/// sleep when it is not needed.
struct ServerThread {
std::thread Worker;

/// A boolean indicating whether or not the worker thread should continue.
std::atomic<bool> Running;

/// The number of currently executing kernels across all devices that need
/// the server thread to be running.
std::atomic<uint32_t> NumUsers;

/// The condition variable used to suspend the thread if no work is needed.
std::condition_variable CV;
std::mutex Mutex;

/// A reference to all the RPC interfaces that the server is handling.
llvm::ArrayRef<void *> Buffers;

/// A reference to the associated generic device for the buffer.
llvm::ArrayRef<plugin::GenericDeviceTy *> Devices;

/// Initialize the worker thread to run in the background.
ServerThread(void *Buffers[], plugin::GenericDeviceTy *Devices[],
size_t Length)
: Running(true), NumUsers(0), CV(), Mutex(), Buffers(Buffers, Length),
Devices(Devices, Length) {}

~ServerThread() { assert(!Running && "Thread not shut down explicitly\n"); }

/// Notify the worker thread that there is a user that needs it.
void notify() {
std::lock_guard<decltype(Mutex)> Lock(Mutex);
NumUsers.fetch_add(1, std::memory_order_relaxed);
CV.notify_all();
}

/// Indicate that one of the dependent users has finished.
void finish() {
[[maybe_unused]] uint32_t Old =
NumUsers.fetch_sub(1, std::memory_order_relaxed);
assert(Old > 0 && "Attempt to signal finish with no pending work");
}

/// Destroy the worker thread and wait.
void shutDown();

/// Initialize the worker thread.
void startThread();

/// Run the server thread to continuously check the RPC interface for work
/// to be done for the device.
void run();
};

public:
/// Pointer to the server thread instance.
std::unique_ptr<ServerThread> Thread;
};

} // namespace llvm::omp::target
Expand Down
8 changes: 7 additions & 1 deletion offload/plugins-nextgen/common/src/PluginInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1057,6 +1057,9 @@ Error GenericDeviceTy::setupRPCServer(GenericPluginTy &Plugin,
if (auto Err = Server.initDevice(*this, Plugin.getGlobalHandler(), Image))
return Err;

if (auto Err = Server.startThread())
return Err;

RPCServer = &Server;
DP("Running an RPC server on device %d\n", getDeviceId());
return Plugin::success();
Expand Down Expand Up @@ -1630,8 +1633,11 @@ Error GenericPluginTy::deinit() {
if (GlobalHandler)
delete GlobalHandler;

if (RPCServer)
if (RPCServer) {
if (Error Err = RPCServer->shutDown())
return Err;
delete RPCServer;
}

if (RecordReplay)
delete RecordReplay;
Expand Down
Loading
Loading