Skip to content

Commit b30565e

Browse files
committed
rpc : enable async operations
Start a dedicated backend thread in the rpc-server and use a message-passing interface for submitting work to it. This will enable asynchronous backend operations and cross-server communication.
1 parent 172c825 commit b30565e

File tree

1 file changed

+139
-33
lines changed

1 file changed

+139
-33
lines changed

ggml-rpc.cpp

Lines changed: 139 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,11 @@
55
#include <cinttypes>
66
#include <string>
77
#include <vector>
8+
#include <queue>
89
#include <memory>
910
#include <mutex>
11+
#include <thread>
12+
#include <condition_variable>
1013
#include <unordered_map>
1114
#include <unordered_set>
1215
#ifdef _WIN32
@@ -17,6 +20,7 @@
1720
# include <windows.h>
1821
# include <winsock2.h>
1922
#else
23+
# include <signal.h>
2024
# include <arpa/inet.h>
2125
# include <sys/socket.h>
2226
# include <sys/types.h>
@@ -89,6 +93,7 @@ enum rpc_cmd {
8993
COPY_TENSOR,
9094
GRAPH_COMPUTE,
9195
GET_DEVICE_MEMORY,
96+
FREE_ALL_BUFFERS,
9297
};
9398

9499
// RPC data structures
@@ -736,6 +741,48 @@ GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const char * endpoint
736741

737742
// RPC server-side implementation
738743

744+
// Minimal thread-safe FIFO used to pass messages between the socket-serving
// threads and the dedicated backend thread. push() may be called from any
// thread; pop() blocks until an element is available.
template <typename T>
class message_queue {
    std::queue<T> queue;
    std::mutex mutex;
    std::condition_variable cvar;

public:
    message_queue() {}

    // Enqueue a copy of `value` and wake a waiting consumer.
    void push(const T &value) {
        {
            // keep the critical section as small as possible
            std::lock_guard<std::mutex> lock(mutex);
            queue.push(value);
        }
        // each push makes exactly one element available, so waking a single
        // waiter is sufficient (notify_all would cause a thundering herd)
        cvar.notify_one();
    }

    // Block until the queue is non-empty, then pop the front element into *out.
    void pop(T* out) {
        std::unique_lock<std::mutex> lock(mutex);
        // predicate form guards against spurious wakeups
        cvar.wait(lock, [this] { return !queue.empty(); });
        *out = queue.front();
        queue.pop();
    }
};
767+
768+
struct rpc_response {
769+
std::vector<uint8_t> output;
770+
bool status;
771+
};
772+
773+
using rpc_response_ptr = std::shared_ptr<rpc_response>;
774+
using response_queue = message_queue<rpc_response_ptr>;
775+
using response_queue_ptr = std::shared_ptr<response_queue>;
776+
777+
struct rpc_request {
778+
rpc_cmd cmd;
779+
std::vector<uint8_t> input;
780+
response_queue_ptr response_queue;
781+
};
782+
using rpc_request_ptr = std::shared_ptr<rpc_request>;
783+
using request_queue = message_queue<rpc_request_ptr>;
784+
using request_queue_ptr = std::shared_ptr<request_queue>;
785+
739786
class rpc_server {
740787
public:
741788
rpc_server(ggml_backend_t backend) : backend(backend) {}
@@ -752,6 +799,7 @@ class rpc_server {
752799
bool copy_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
753800
bool graph_compute(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
754801

802+
void free_all_buffers();
755803
private:
756804
ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor);
757805
ggml_tensor * create_node(uint64_t id,
@@ -1046,76 +1094,122 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, std::vector<u
10461094
return true;
10471095
}
10481096

1049-
rpc_server::~rpc_server() {
1097+
void rpc_server::free_all_buffers() {
10501098
for (auto buffer : buffers) {
10511099
ggml_backend_buffer_free(buffer);
10521100
}
1101+
buffers.clear();
1102+
}
1103+
1104+
rpc_server::~rpc_server() {
1105+
free_all_buffers();
10531106
}
10541107

1055-
static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t free_mem, size_t total_mem) {
1108+
static void process_requests(ggml_backend_t backend, request_queue_ptr requestq) {
10561109
rpc_server server(backend);
10571110
while (true) {
1058-
uint8_t cmd;
1059-
if (!recv_data(sockfd, &cmd, 1)) {
1060-
break;
1061-
}
1062-
std::vector<uint8_t> input;
1063-
std::vector<uint8_t> output;
1064-
uint64_t input_size;
1065-
if (!recv_data(sockfd, &input_size, sizeof(input_size))) {
1066-
break;
1067-
}
1068-
input.resize(input_size);
1069-
if (!recv_data(sockfd, input.data(), input_size)) {
1070-
break;
1071-
}
1111+
rpc_request_ptr request;
1112+
requestq->pop(&request);
1113+
rpc_response_ptr response = std::make_shared<rpc_response>();
10721114
bool ok = true;
1073-
switch (cmd) {
1115+
switch (request->cmd) {
10741116
case ALLOC_BUFFER: {
1075-
ok = server.alloc_buffer(input, output);
1117+
ok = server.alloc_buffer(request->input, response->output);
10761118
break;
10771119
}
10781120
case GET_ALIGNMENT: {
1079-
server.get_alignment(output);
1121+
server.get_alignment(response->output);
10801122
break;
10811123
}
10821124
case GET_MAX_SIZE: {
1083-
server.get_max_size(output);
1125+
server.get_max_size(response->output);
10841126
break;
10851127
}
10861128
case BUFFER_GET_BASE: {
1087-
ok = server.buffer_get_base(input, output);
1129+
ok = server.buffer_get_base(request->input, response->output);
10881130
break;
10891131
}
10901132
case FREE_BUFFER: {
1091-
ok = server.free_buffer(input);
1133+
ok = server.free_buffer(request->input);
10921134
break;
10931135
}
10941136
case BUFFER_CLEAR: {
1095-
ok = server.buffer_clear(input);
1137+
ok = server.buffer_clear(request->input);
10961138
break;
10971139
}
10981140
case SET_TENSOR: {
1099-
ok = server.set_tensor(input);
1141+
ok = server.set_tensor(request->input);
11001142
break;
11011143
}
11021144
case GET_TENSOR: {
1103-
ok = server.get_tensor(input, output);
1145+
ok = server.get_tensor(request->input, response->output);
11041146
break;
11051147
}
11061148
case COPY_TENSOR: {
1107-
ok = server.copy_tensor(input, output);
1149+
ok = server.copy_tensor(request->input, response->output);
11081150
break;
11091151
}
11101152
case GRAPH_COMPUTE: {
1111-
ok = server.graph_compute(input, output);
1153+
ok = server.graph_compute(request->input, response->output);
1154+
break;
1155+
}
1156+
case GET_DEVICE_MEMORY: {
1157+
break;
1158+
}
1159+
case FREE_ALL_BUFFERS: {
1160+
server.free_all_buffers();
1161+
continue;
1162+
}
1163+
default: {
1164+
fprintf(stderr, "Unknown command: %d\n", request->cmd);
1165+
ok = false;
1166+
}
1167+
}
1168+
response->status = ok;
1169+
request->response_queue->push(response);
1170+
}
1171+
}
1172+
1173+
// NOTE(review): per-client socket loop. Reads framed commands from the socket,
// hands backend-touching work to the shared request queue, blocks for the
// response, and writes it back. GET_DEVICE_MEMORY is answered locally.
static void rpc_serve_client(request_queue_ptr requestq, sockfd_t sockfd, size_t free_mem, size_t total_mem) {
1174+
// per-client queue on which the backend thread delivers responses
auto responseq = std::make_shared<response_queue>();
1175+
while (true) {
1176+
uint8_t cmd;
1177+
if (!recv_data(sockfd, &cmd, 1)) {
1178+
break;
1179+
}
1180+
auto request = std::make_shared<rpc_request>();
1181+
request->cmd = (rpc_cmd)cmd;
1182+
request->response_queue = responseq;
1183+
// NOTE(review): input_size arrives unvalidated from the network and is used
// to resize the input buffer below — confirm an upper bound is enforced.
uint64_t input_size;
1184+
if (!recv_data(sockfd, &input_size, sizeof(input_size))) {
1185+
break;
1186+
}
1187+
request->input.resize(input_size);
1188+
if (!recv_data(sockfd, request->input.data(), input_size)) {
1189+
break;
1190+
}
1191+
bool ok = true;
1192+
auto response = std::make_shared<rpc_response>();
1193+
switch (cmd) {
1194+
case ALLOC_BUFFER:
1195+
case GET_ALIGNMENT:
1196+
case GET_MAX_SIZE:
1197+
case BUFFER_GET_BASE:
1198+
case FREE_BUFFER:
1199+
case BUFFER_CLEAR:
1200+
case SET_TENSOR:
1201+
case GET_TENSOR:
1202+
case COPY_TENSOR:
1203+
case GRAPH_COMPUTE: {
1204+
// backend-touching commands: submit to the backend thread, then block
// until the response is pushed back on this client's queue
requestq->push(request);
1205+
responseq->pop(&response);
11121206
break;
11131207
}
11141208
// served locally from the values captured at server start — no backend work
case GET_DEVICE_MEMORY: {
11151209
// output serialization format: | free (8 bytes) | total (8 bytes) |
1116-
output.resize(2*sizeof(uint64_t), 0);
1117-
memcpy(output.data(), &free_mem, sizeof(free_mem));
1118-
memcpy(output.data() + sizeof(uint64_t), &total_mem, sizeof(total_mem));
1210+
response->output.resize(2*sizeof(uint64_t), 0);
1211+
memcpy(response->output.data(), &free_mem, sizeof(free_mem));
1212+
memcpy(response->output.data() + sizeof(uint64_t), &total_mem, sizeof(total_mem));
11191213
break;
11201214
}
11211215
default: {
@@ -1126,17 +1220,29 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
11261220
if (!ok) {
11271221
break;
11281222
}
1129-
uint64_t output_size = output.size();
1223+
// reply framing: | size (8 bytes) | payload |
uint64_t output_size = response->output.size();
11301224
if (!send_data(sockfd, &output_size, sizeof(output_size))) {
11311225
break;
11321226
}
1133-
if (!send_data(sockfd, output.data(), output_size)) {
1227+
if (!send_data(sockfd, response->output.data(), output_size)) {
11341228
break;
11351229
}
11361230
}
1231+
// client disconnected: ask the backend thread to free any buffers left
// allocated; fire-and-forget — FREE_ALL_BUFFERS produces no response
auto request = std::make_shared<rpc_request>();
1232+
request->cmd = FREE_ALL_BUFFERS;
1233+
requestq->push(request);
11371234
}
11381235

11391236
void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem) {
1237+
#ifndef _WIN32
1238+
// prevent SIGPIPE when writing to closed socket
1239+
signal(SIGPIPE, SIG_IGN);
1240+
#endif
1241+
auto requestq = std::make_shared<request_queue>();
1242+
std::thread backend_thread = std::thread([=] {
1243+
process_requests(backend, requestq);
1244+
});
1245+
11401246
std::string host;
11411247
int port;
11421248
if (!parse_endpoint(endpoint, host, port)) {
@@ -1164,7 +1270,7 @@ void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free
11641270
return;
11651271
}
11661272
printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem);
1167-
rpc_serve_client(backend, client_socket->fd, free_mem, total_mem);
1273+
rpc_serve_client(requestq, client_socket->fd, free_mem, total_mem);
11681274
printf("Client connection closed\n");
11691275
}
11701276
#ifdef _WIN32

0 commit comments

Comments
 (0)