@@ -5,8 +5,11 @@
 #include <cinttypes>
 #include <string>
 #include <vector>
+#include <queue>
 #include <memory>
 #include <mutex>
+#include <thread>
+#include <condition_variable>
 #include <unordered_map>
 #include <unordered_set>
 #ifdef _WIN32
@@ -17,6 +20,7 @@
 #  include <windows.h>
 #  include <winsock2.h>
 #else
+#  include <signal.h>
 #  include <arpa/inet.h>
 #  include <sys/socket.h>
 #  include <sys/types.h>
@@ -89,6 +93,7 @@ enum rpc_cmd {
     COPY_TENSOR,
     GRAPH_COMPUTE,
     GET_DEVICE_MEMORY,
+    FREE_ALL_BUFFERS,
 };
 
 // RPC data structures
@@ -736,6 +741,48 @@ GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const char * endpoint
 
 // RPC server-side implementation
 
+template <typename T>
+class message_queue {
+    std::queue<T> queue;
+    std::mutex mutex;
+    std::condition_variable cvar;
+
+public:
+    message_queue() {}
+
+    void push(const T &value) {
+        std::unique_lock<std::mutex> lock(mutex);
+        queue.push(value);
+        lock.unlock();
+        cvar.notify_all();
+    }
+
+    void pop(T* out) {
+        std::unique_lock<std::mutex> lock(mutex);
+        cvar.wait(lock, [this] { return queue.size() > 0; });
+        *out = queue.front();
+        queue.pop();
+    }
+};
+
+struct rpc_response {
+    std::vector<uint8_t> output;
+    bool status;
+};
+
+using rpc_response_ptr = std::shared_ptr<rpc_response>;
+using response_queue = message_queue<rpc_response_ptr>;
+using response_queue_ptr = std::shared_ptr<response_queue>;
+
+struct rpc_request {
+    rpc_cmd cmd;
+    std::vector<uint8_t> input;
+    response_queue_ptr response_queue;
+};
+using rpc_request_ptr = std::shared_ptr<rpc_request>;
+using request_queue = message_queue<rpc_request_ptr>;
+using request_queue_ptr = std::shared_ptr<request_queue>;
+
 class rpc_server {
 public:
     rpc_server(ggml_backend_t backend) : backend(backend) {}
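Note (added for illustration, not part of the diff): `message_queue<T>` is the hand-off primitive between the socket-facing thread and the backend thread — `push()` enqueues under the mutex and wakes waiters, while `pop()` blocks until an item is available. A minimal standalone sketch, assuming it is compiled together with the `message_queue<T>` template from the hunk above (and its `<queue>`, `<mutex>`, `<condition_variable>` includes):

```cpp
// Sketch: one consumer thread blocks in pop() until the main thread push()es,
// mirroring the request/response hand-off between rpc_serve_client (producer)
// and process_requests (consumer) introduced later in this diff.
#include <cstdio>
#include <memory>
#include <string>
#include <thread>

int main() {
    message_queue<std::shared_ptr<std::string>> q;

    std::thread consumer([&q] {
        std::shared_ptr<std::string> msg;
        q.pop(&msg);                                  // blocks until a message arrives
        printf("consumer got: %s\n", msg->c_str());
    });

    q.push(std::make_shared<std::string>("hello"));   // wakes the consumer
    consumer.join();
    return 0;
}
```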
@@ -752,6 +799,7 @@ class rpc_server {
     bool copy_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
     bool graph_compute(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
 
+    void free_all_buffers();
 private:
     ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor);
     ggml_tensor * create_node(uint64_t id,
@@ -1046,76 +1094,122 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, std::vector<u
     return true;
 }
 
-rpc_server::~rpc_server() {
+void rpc_server::free_all_buffers() {
     for (auto buffer : buffers) {
         ggml_backend_buffer_free(buffer);
     }
+    buffers.clear();
+}
+
+rpc_server::~rpc_server() {
+    free_all_buffers();
 }
 
-static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t free_mem, size_t total_mem) {
+static void process_requests(ggml_backend_t backend, request_queue_ptr requestq) {
     rpc_server server(backend);
     while (true) {
-        uint8_t cmd;
-        if (!recv_data(sockfd, &cmd, 1)) {
-            break;
-        }
-        std::vector<uint8_t> input;
-        std::vector<uint8_t> output;
-        uint64_t input_size;
-        if (!recv_data(sockfd, &input_size, sizeof(input_size))) {
-            break;
-        }
-        input.resize(input_size);
-        if (!recv_data(sockfd, input.data(), input_size)) {
-            break;
-        }
+        rpc_request_ptr request;
+        requestq->pop(&request);
+        rpc_response_ptr response = std::make_shared<rpc_response>();
         bool ok = true;
-        switch (cmd) {
+        switch (request->cmd) {
             case ALLOC_BUFFER: {
-                ok = server.alloc_buffer(input, output);
+                ok = server.alloc_buffer(request->input, response->output);
                 break;
             }
             case GET_ALIGNMENT: {
-                server.get_alignment(output);
+                server.get_alignment(response->output);
                 break;
             }
             case GET_MAX_SIZE: {
-                server.get_max_size(output);
+                server.get_max_size(response->output);
                 break;
             }
             case BUFFER_GET_BASE: {
-                ok = server.buffer_get_base(input, output);
+                ok = server.buffer_get_base(request->input, response->output);
                 break;
             }
             case FREE_BUFFER: {
-                ok = server.free_buffer(input);
+                ok = server.free_buffer(request->input);
                 break;
             }
             case BUFFER_CLEAR: {
-                ok = server.buffer_clear(input);
+                ok = server.buffer_clear(request->input);
                 break;
             }
             case SET_TENSOR: {
-                ok = server.set_tensor(input);
+                ok = server.set_tensor(request->input);
                 break;
            }
            case GET_TENSOR: {
-                ok = server.get_tensor(input, output);
+                ok = server.get_tensor(request->input, response->output);
                 break;
             }
             case COPY_TENSOR: {
-                ok = server.copy_tensor(input, output);
+                ok = server.copy_tensor(request->input, response->output);
                 break;
             }
             case GRAPH_COMPUTE: {
-                ok = server.graph_compute(input, output);
+                ok = server.graph_compute(request->input, response->output);
+                break;
+            }
+            case GET_DEVICE_MEMORY: {
+                break;
+            }
+            case FREE_ALL_BUFFERS: {
+                server.free_all_buffers();
+                continue;
+            }
+            default: {
+                fprintf(stderr, "Unknown command: %d\n", request->cmd);
+                ok = false;
+            }
+        }
+        response->status = ok;
+        request->response_queue->push(response);
+    }
+}
+
+static void rpc_serve_client(request_queue_ptr requestq, sockfd_t sockfd, size_t free_mem, size_t total_mem) {
+    auto responseq = std::make_shared<response_queue>();
+    while (true) {
+        uint8_t cmd;
+        if (!recv_data(sockfd, &cmd, 1)) {
+            break;
+        }
+        auto request = std::make_shared<rpc_request>();
+        request->cmd = (rpc_cmd)cmd;
+        request->response_queue = responseq;
+        uint64_t input_size;
+        if (!recv_data(sockfd, &input_size, sizeof(input_size))) {
+            break;
+        }
+        request->input.resize(input_size);
+        if (!recv_data(sockfd, request->input.data(), input_size)) {
+            break;
+        }
+        bool ok = true;
+        auto response = std::make_shared<rpc_response>();
+        switch (cmd) {
+            case ALLOC_BUFFER:
+            case GET_ALIGNMENT:
+            case GET_MAX_SIZE:
+            case BUFFER_GET_BASE:
+            case FREE_BUFFER:
+            case BUFFER_CLEAR:
+            case SET_TENSOR:
+            case GET_TENSOR:
+            case COPY_TENSOR:
+            case GRAPH_COMPUTE: {
+                requestq->push(request);
+                responseq->pop(&response);
                 break;
             }
             case GET_DEVICE_MEMORY: {
                 // output serialization format: | free (8 bytes) | total (8 bytes) |
-                output.resize(2*sizeof(uint64_t), 0);
-                memcpy(output.data(), &free_mem, sizeof(free_mem));
-                memcpy(output.data() + sizeof(uint64_t), &total_mem, sizeof(total_mem));
+                response->output.resize(2*sizeof(uint64_t), 0);
+                memcpy(response->output.data(), &free_mem, sizeof(free_mem));
+                memcpy(response->output.data() + sizeof(uint64_t), &total_mem, sizeof(total_mem));
                 break;
             }
             default: {
@@ -1126,17 +1220,29 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
         if (!ok) {
             break;
         }
-        uint64_t output_size = output.size();
+        uint64_t output_size = response->output.size();
         if (!send_data(sockfd, &output_size, sizeof(output_size))) {
             break;
         }
-        if (!send_data(sockfd, output.data(), output_size)) {
+        if (!send_data(sockfd, response->output.data(), output_size)) {
             break;
         }
     }
+    auto request = std::make_shared<rpc_request>();
+    request->cmd = FREE_ALL_BUFFERS;
+    requestq->push(request);
 }
 
 void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem) {
+#ifndef _WIN32
+    // prevent SIGPIPE when writing to closed socket
+    signal(SIGPIPE, SIG_IGN);
+#endif
+    auto requestq = std::make_shared<request_queue>();
+    std::thread backend_thread = std::thread([=] {
+        process_requests(backend, requestq);
+    });
+
     std::string host;
     int port;
     if (!parse_endpoint(endpoint, host, port)) {
@@ -1164,7 +1270,7 @@ void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free
             return;
         }
         printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem);
-        rpc_serve_client(backend, client_socket->fd, free_mem, total_mem);
+        rpc_serve_client(requestq, client_socket->fd, free_mem, total_mem);
         printf("Client connection closed\n");
     }
 #ifdef _WIN32
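Note (added for illustration, not part of the diff): after this change, `start_rpc_server()` spawns the backend thread internally and then blocks in the accept loop, so a host program is driven exactly as before. A minimal sketch of such a host, assuming the public declaration lives in `ggml-rpc.h`, a CPU backend, and placeholder values for the endpoint and the advertised memory:

```cpp
// Hypothetical host (sketch): serve a CPU backend over RPC.
// start_rpc_server() runs process_requests() on a dedicated thread and
// handles one client connection at a time on the calling thread.
#include "ggml-backend.h"
#include "ggml-rpc.h"

int main() {
    ggml_backend_t backend = ggml_backend_cpu_init();

    // placeholder figures returned to clients by GET_DEVICE_MEMORY
    size_t free_mem  = 8ull * 1024 * 1024 * 1024;
    size_t total_mem = 8ull * 1024 * 1024 * 1024;

    start_rpc_server(backend, "0.0.0.0:50052", free_mem, total_mem);  // blocks

    ggml_backend_free(backend);
    return 0;
}
```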