#include <cinttypes>
#include <condition_variable>
#include <memory>
#include <mutex>
#include <queue>
#include <string>
#include <thread>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#ifdef _WIN32
@@ -736,6 +739,48 @@ GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const char * endpoint
736
739
737
740
// RPC server-side implementation
738
741
742
// Thread-safe FIFO used to pass messages between the client (socket) threads
// and the backend worker thread. push() may be called from any thread;
// pop() blocks until a message is available.
template <typename T>
class message_queue {
    std::queue<T> queue;
    std::mutex mutex;
    std::condition_variable cvar;

public:
    message_queue() = default;

    // Enqueue a copy of value and wake all waiting consumers.
    void push(const T & value) {
        {
            std::lock_guard<std::mutex> lock(mutex);
            queue.push(value);
        }
        // notify outside the critical section so a woken consumer does not
        // immediately block on the mutex we still hold
        cvar.notify_all();
    }

    // Block until a message is available, then transfer it into *out.
    void pop(T * out) {
        std::unique_lock<std::mutex> lock(mutex);
        // predicate form handles spurious wakeups
        cvar.wait(lock, [this] { return !queue.empty(); });
        // move instead of copy: T is a shared_ptr at every call site in this
        // file, so this avoids a redundant atomic refcount round-trip
        *out = std::move(queue.front());
        queue.pop();
    }
};
765
+
766
// Result of one RPC command: serialized payload plus a success flag.
struct rpc_response {
    // serialized command output (format depends on the command)
    std::vector<uint8_t> output;
    // true when the command was processed successfully;
    // default-initialized so a freshly created response is never read as
    // an indeterminate (uninitialized) success flag
    bool status = false;
};
770
+
771
// responses are shared between the worker thread that fills them in and the
// client thread that serializes them back to the socket
using rpc_response_ptr = std::shared_ptr<rpc_response>;
// per-client queue: the client's socket thread blocks on it until the backend
// worker pushes the response for the request it just submitted
using response_queue = message_queue<rpc_response_ptr>;
using response_queue_ptr = std::shared_ptr<response_queue>;
774
+
775
// A single command received from a client, handed off to the backend worker.
struct rpc_request {
    rpc_cmd cmd;                        // which RPC operation to execute
    std::vector<uint8_t> input;         // serialized command arguments
    response_queue_ptr response_queue;  // queue the worker pushes the result to
};
using rpc_request_ptr = std::shared_ptr<rpc_request>;
// single queue consumed by the backend worker thread (process_requests);
// all clients push into it, which serializes access to the ggml backend
using request_queue = message_queue<rpc_request_ptr>;
using request_queue_ptr = std::shared_ptr<request_queue>;
783
+
739
784
class rpc_server {
740
785
public:
741
786
rpc_server (ggml_backend_t backend) : backend(backend) {}
@@ -1052,70 +1097,107 @@ rpc_server::~rpc_server() {
1052
1097
}
1053
1098
}
1054
1099
1055
// Backend worker loop: pops requests from the shared request queue and
// executes them one at a time on this thread, so the ggml backend is only
// ever touched from a single thread even with multiple connected clients.
// Runs forever; started once per server (see start_rpc_server).
static void process_requests(ggml_backend_t backend, request_queue_ptr requestq) {
    rpc_server server(backend);
    while (true) {
        rpc_request_ptr request;
        requestq->pop(&request);  // blocks until a client submits a command
        rpc_response_ptr response = std::make_shared<rpc_response>();
        bool ok = true;
        switch (request->cmd) {
            case ALLOC_BUFFER: {
                ok = server.alloc_buffer(request->input, response->output);
                break;
            }
            case GET_ALIGNMENT: {
                server.get_alignment(response->output);
                break;
            }
            case GET_MAX_SIZE: {
                server.get_max_size(response->output);
                break;
            }
            case BUFFER_GET_BASE: {
                ok = server.buffer_get_base(request->input, response->output);
                break;
            }
            case FREE_BUFFER: {
                ok = server.free_buffer(request->input);
                break;
            }
            case BUFFER_CLEAR: {
                ok = server.buffer_clear(request->input);
                break;
            }
            case SET_TENSOR: {
                ok = server.set_tensor(request->input);
                break;
            }
            case GET_TENSOR: {
                ok = server.get_tensor(request->input, response->output);
                break;
            }
            case COPY_TENSOR: {
                ok = server.copy_tensor(request->input, response->output);
                break;
            }
            case GRAPH_COMPUTE: {
                ok = server.graph_compute(request->input, response->output);
                break;
            }
            case GET_DEVICE_MEMORY: {
                // answered directly in rpc_serve_client (which owns the
                // free/total memory values) and is not forwarded here; if it
                // ever arrives, reply with an empty success response
                break;
            }
            default: {
                fprintf(stderr, "Unknown command: %d\n", request->cmd);
                ok = false;
            }
        }
        response->status = ok;
        // wake the client thread blocked on its response queue
        request->response_queue->push(response);
    }
}
1160
+
1161
+ static void rpc_serve_client (request_queue_ptr requestq, sockfd_t sockfd, size_t free_mem, size_t total_mem) {
1162
+ auto responseq = std::make_shared<response_queue>();
1163
+ while (true ) {
1164
+ uint8_t cmd;
1165
+ if (!recv_data (sockfd, &cmd, 1 )) {
1166
+ break ;
1167
+ }
1168
+ auto request = std::make_shared<rpc_request>();
1169
+ request->cmd = (rpc_cmd)cmd;
1170
+ request->response_queue = responseq;
1171
+ uint64_t input_size;
1172
+ if (!recv_data (sockfd, &input_size, sizeof (input_size))) {
1173
+ break ;
1174
+ }
1175
+ request->input .resize (input_size);
1176
+ if (!recv_data (sockfd, request->input .data (), input_size)) {
1177
+ break ;
1178
+ }
1179
+ bool ok = true ;
1180
+ auto response = std::make_shared<rpc_response>();
1181
+ switch (cmd) {
1182
+ case ALLOC_BUFFER:
1183
+ case GET_ALIGNMENT:
1184
+ case GET_MAX_SIZE:
1185
+ case BUFFER_GET_BASE:
1186
+ case FREE_BUFFER:
1187
+ case BUFFER_CLEAR:
1188
+ case SET_TENSOR:
1189
+ case GET_TENSOR:
1190
+ case COPY_TENSOR:
1191
+ case GRAPH_COMPUTE: {
1192
+ requestq->push (request);
1193
+ responseq->pop (&response);
1112
1194
break ;
1113
1195
}
1114
1196
case GET_DEVICE_MEMORY: {
1115
1197
// output serialization format: | free (8 bytes) | total (8 bytes) |
1116
- output.resize (2 *sizeof (uint64_t ), 0 );
1117
- memcpy (output.data (), &free_mem, sizeof (free_mem));
1118
- memcpy (output.data () + sizeof (uint64_t ), &total_mem, sizeof (total_mem));
1198
+ response-> output .resize (2 *sizeof (uint64_t ), 0 );
1199
+ memcpy (response-> output .data (), &free_mem, sizeof (free_mem));
1200
+ memcpy (response-> output .data () + sizeof (uint64_t ), &total_mem, sizeof (total_mem));
1119
1201
break ;
1120
1202
}
1121
1203
default : {
@@ -1126,17 +1208,22 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
1126
1208
if (!ok) {
1127
1209
break ;
1128
1210
}
1129
- uint64_t output_size = output.size ();
1211
+ uint64_t output_size = response-> output .size ();
1130
1212
if (!send_data (sockfd, &output_size, sizeof (output_size))) {
1131
1213
break ;
1132
1214
}
1133
- if (!send_data (sockfd, output.data (), output_size)) {
1215
+ if (!send_data (sockfd, response-> output .data (), output_size)) {
1134
1216
break ;
1135
1217
}
1136
1218
}
1137
1219
}
1138
1220
1139
1221
void start_rpc_server (ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem) {
1222
+ auto requestq = std::make_shared<request_queue>();
1223
+ std::thread backend_thread = std::thread ([=] {
1224
+ process_requests (backend, requestq);
1225
+ });
1226
+
1140
1227
std::string host;
1141
1228
int port;
1142
1229
if (!parse_endpoint (endpoint, host, port)) {
@@ -1164,7 +1251,7 @@ void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free
1164
1251
return ;
1165
1252
}
1166
1253
printf (" Accepted client connection, free_mem=%zu, total_mem=%zu\n " , free_mem, total_mem);
1167
- rpc_serve_client (backend , client_socket->fd , free_mem, total_mem);
1254
+ rpc_serve_client (requestq , client_socket->fd , free_mem, total_mem);
1168
1255
printf (" Client connection closed\n " );
1169
1256
}
1170
1257
#ifdef _WIN32
0 commit comments