Skip to content

Commit afbe094

Browse files
ggerganovarthw
authored andcommitted
Merge commit from fork
1 parent a1b432b commit afbe094

File tree

4 files changed

+53
-3
lines changed

4 files changed

+53
-3
lines changed

examples/rpc/README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
## Overview
22

3+
> [!IMPORTANT]
4+
> This example and the RPC backend are currently in a proof-of-concept development stage. As such, the functionality is fragile and
5+
> insecure. **Never run the RPC server on an open network or in a sensitive environment!**
6+
37
The `rpc-server` allows running `ggml` backend on a remote host.
48
The RPC backend communicates with one or several instances of `rpc-server` and offloads computations to them.
59
This can be used for distributed LLM inference with `llama.cpp` in the following way:

examples/rpc/rpc-server.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
#include <stdio.h>
1717

1818
struct rpc_server_params {
19-
std::string host = "0.0.0.0";
19+
std::string host = "127.0.0.1";
2020
int port = 50052;
2121
size_t backend_mem = 0;
2222
};
@@ -114,6 +114,17 @@ int main(int argc, char * argv[]) {
114114
fprintf(stderr, "Invalid parameters\n");
115115
return 1;
116116
}
117+
118+
if (params.host != "127.0.0.1") {
119+
fprintf(stderr, "\n");
120+
fprintf(stderr, "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n");
121+
fprintf(stderr, "WARNING: Host ('%s') is != '127.0.0.1'\n", params.host.c_str());
122+
fprintf(stderr, " Never expose the RPC server to an open network!\n");
123+
fprintf(stderr, " This is an experimental feature and is not secure!\n");
124+
fprintf(stderr, "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n");
125+
fprintf(stderr, "\n");
126+
}
127+
117128
ggml_backend_t backend = create_backend();
118129
if (!backend) {
119130
fprintf(stderr, "Failed to create backend\n");

ggml/src/ggml-rpc.cpp

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,10 @@ static std::shared_ptr<socket_t> create_server_socket(const char * host, int por
197197
fprintf(stderr, "Failed to set SO_REUSEADDR\n");
198198
return nullptr;
199199
}
200+
if (inet_addr(host) == INADDR_NONE) {
201+
fprintf(stderr, "Invalid host address: %s\n", host);
202+
return nullptr;
203+
}
200204
struct sockaddr_in serv_addr;
201205
serv_addr.sin_family = AF_INET;
202206
serv_addr.sin_addr.s_addr = inet_addr(host);
@@ -879,6 +883,14 @@ ggml_tensor * rpc_server::deserialize_tensor(struct ggml_context * ctx, const rp
879883
if (result->buffer && buffers.find(result->buffer) == buffers.end()) {
880884
return nullptr;
881885
}
886+
887+
// require that the tensor data does not go beyond the buffer end
888+
uint64_t tensor_size = (uint64_t) ggml_nbytes(result);
889+
uint64_t buffer_start = (uint64_t) ggml_backend_buffer_get_base(result->buffer);
890+
uint64_t buffer_size = (uint64_t) ggml_backend_buffer_get_size(result->buffer);
891+
GGML_ASSERT(tensor->data + tensor_size >= tensor->data); // check for overflow
892+
GGML_ASSERT(tensor->data >= buffer_start && tensor->data + tensor_size <= buffer_start + buffer_size);
893+
882894
result->op = (ggml_op) tensor->op;
883895
for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) {
884896
result->op_params[i] = tensor->op_params[i];
@@ -898,7 +910,7 @@ bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
898910
const rpc_tensor * in_tensor = (const rpc_tensor *)input.data();
899911
uint64_t offset;
900912
memcpy(&offset, input.data() + sizeof(rpc_tensor), sizeof(offset));
901-
size_t size = input.size() - sizeof(rpc_tensor) - sizeof(offset);
913+
const size_t size = input.size() - sizeof(rpc_tensor) - sizeof(offset);
902914

903915
struct ggml_init_params params {
904916
/*.mem_size =*/ ggml_tensor_overhead(),
@@ -913,6 +925,17 @@ bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
913925
return false;
914926
}
915927
GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu\n", __func__, (void*)tensor->buffer, tensor->data, offset, size);
928+
929+
// sanitize tensor->data
930+
{
931+
const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer);
932+
const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);
933+
934+
if (in_tensor->data + offset < p0 || in_tensor->data + offset >= p1 || size > (p1 - in_tensor->data - offset)) {
935+
GGML_ABORT("[%s] tensor->data out of bounds\n", __func__);
936+
}
937+
}
938+
916939
const void * data = input.data() + sizeof(rpc_tensor) + sizeof(offset);
917940
ggml_backend_tensor_set(tensor, data, offset, size);
918941
ggml_free(ctx);
@@ -943,6 +966,17 @@ bool rpc_server::get_tensor(const std::vector<uint8_t> & input, std::vector<uint
943966
return false;
944967
}
945968
GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %" PRIu64 "\n", __func__, (void*)tensor->buffer, tensor->data, offset, size);
969+
970+
// sanitize tensor->data
971+
{
972+
const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer);
973+
const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);
974+
975+
if (in_tensor->data + offset < p0 || in_tensor->data + offset >= p1 || size > (p1 - in_tensor->data - offset)) {
976+
GGML_ABORT("[%s] tensor->data out of bounds\n", __func__);
977+
}
978+
}
979+
946980
// output serialization format: | data (size bytes) |
947981
output.resize(size, 0);
948982
ggml_backend_tensor_get(tensor, output.data(), offset, size);

ggml/src/ggml.c

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3724,7 +3724,8 @@ static struct ggml_tensor * ggml_new_tensor_impl(
37243724
struct ggml_tensor * view_src,
37253725
size_t view_offs) {
37263726

3727-
assert(n_dims >= 1 && n_dims <= GGML_MAX_DIMS);
3727+
GGML_ASSERT(type >= 0 && type < GGML_TYPE_COUNT);
3728+
GGML_ASSERT(n_dims >= 1 && n_dims <= GGML_MAX_DIMS);
37283729

37293730
// find the base tensor and absolute offset
37303731
if (view_src != NULL && view_src->view_src != NULL) {

0 commit comments

Comments
 (0)