Skip to content

Commit 1e37894

Browse files
committed
rpc : prevent crashes on invalid input
Add more checks which prevent RPC server from crashing if invalid input is received from client
1 parent 98a532d commit 1e37894

File tree

1 file changed

+13
-1
lines changed

1 file changed

+13
-1
lines changed

ggml/src/ggml-rpc.cpp

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1098,13 +1098,23 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
10981098
if (!recv_data(sockfd, &cmd, 1)) {
10991099
break;
11001100
}
1101+
if (cmd > GET_DEVICE_MEMORY) {
1102+
// fail fast if the command is invalid
1103+
fprintf(stderr, "Unknown command: %d\n", cmd);
1104+
break;
1105+
}
11011106
std::vector<uint8_t> input;
11021107
std::vector<uint8_t> output;
11031108
uint64_t input_size;
11041109
if (!recv_data(sockfd, &input_size, sizeof(input_size))) {
11051110
break;
11061111
}
1107-
input.resize(input_size);
1112+
try {
1113+
input.resize(input_size);
1114+
} catch (const std::bad_alloc & e) {
1115+
fprintf(stderr, "Failed to allocate input buffer of size %" PRIu64 "\n", input_size);
1116+
break;
1117+
}
11081118
if (!recv_data(sockfd, input.data(), input_size)) {
11091119
break;
11101120
}
@@ -1203,8 +1213,10 @@ void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free
12031213
return;
12041214
}
12051215
printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem);
1216+
fflush(stdout);
12061217
rpc_serve_client(backend, client_socket->fd, free_mem, total_mem);
12071218
printf("Client connection closed\n");
1219+
fflush(stdout);
12081220
}
12091221
#ifdef _WIN32
12101222
WSACleanup();

0 commit comments

Comments
 (0)