Skip to content

Commit 2ce50d1

Browse files
committed
rpc : resource management rework
1 parent 1d8fca7 commit 2ce50d1

File tree

1 file changed

+87
-49
lines changed

1 file changed

+87
-49
lines changed

ggml-rpc.cpp

Lines changed: 87 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -96,27 +96,37 @@ static ggml_guid_t ggml_backend_rpc_guid() {
9696
return &guid;
9797
}
9898

99-
struct ggml_backend_rpc_buffer_type_context {
99+
struct rpc_backend {
100+
int ref_count;
101+
std::string endpoint;
100102
std::shared_ptr<socket_t> sock;
103+
ggml_backend_t backend;
104+
};
105+
106+
using rpc_backend_ptr = std::shared_ptr<rpc_backend>;
107+
108+
struct ggml_backend_rpc_buffer_type_context {
109+
std::shared_ptr<rpc_backend> back;
101110
std::string name;
102111
size_t alignment;
103112
size_t max_size;
104113
};
105114

106115
struct ggml_backend_rpc_context {
107-
std::string endpoint;
108116
std::string name;
109-
std::shared_ptr<socket_t> sock;
117+
std::shared_ptr<rpc_backend> back;
110118
ggml_backend_buffer_type_t buft;
111119
};
112120

113121
struct ggml_backend_rpc_buffer_context {
114-
std::shared_ptr<socket_t> sock;
122+
std::shared_ptr<rpc_backend> back;
115123
std::unordered_map<ggml_backend_buffer_t, void *> base_cache;
116124
uint64_t remote_ptr;
117125
std::string name;
118126
};
119127

128+
static std::unordered_map<std::string, rpc_backend_ptr> instances;
129+
120130
// RPC helper functions
121131

122132
static std::shared_ptr<socket_t> make_socket(sockfd_t fd) {
@@ -231,14 +241,13 @@ static bool recv_data(sockfd_t sockfd, void * data, size_t size) {
231241
return true;
232242
}
233243

234-
static bool parse_endpoint(const char * endpoint, std::string & host, int & port) {
235-
std::string str(endpoint);
236-
size_t pos = str.find(':');
244+
static bool parse_endpoint(const std::string & endpoint, std::string & host, int & port) {
245+
size_t pos = endpoint.find(':');
237246
if (pos == std::string::npos) {
238247
return false;
239248
}
240-
host = str.substr(0, pos);
241-
port = std::stoi(str.substr(pos + 1));
249+
host = endpoint.substr(0, pos);
250+
port = std::stoi(endpoint.substr(pos + 1));
242251
return true;
243252
}
244253

@@ -273,6 +282,22 @@ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cm
273282

274283
// RPC client-side implementation
275284

285+
static void free_rpc_backend(rpc_backend_ptr rpc_back) {
286+
ggml_backend_t backend = rpc_back->backend;
287+
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
288+
std::string endpoint = rpc_back->endpoint;
289+
ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)rpc_ctx->buft->context;
290+
GGML_PRINT_DEBUG("[%s] closing connection to %s\n", __func__, endpoint.c_str());
291+
delete buft_ctx;
292+
delete rpc_ctx->buft;
293+
delete rpc_ctx;
294+
delete backend;
295+
instances.erase(endpoint);
296+
#ifdef _WIN32
297+
WSACleanup();
298+
#endif
299+
}
300+
276301
GGML_CALL static const char * ggml_backend_rpc_buffer_get_name(ggml_backend_buffer_t buffer) {
277302
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
278303
return ctx->name.c_str();
@@ -285,9 +310,13 @@ GGML_CALL static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t
285310
uint64_t remote_ptr = ctx->remote_ptr;
286311
memcpy(input.data(), &remote_ptr, sizeof(remote_ptr));
287312
std::vector<uint8_t> output;
288-
bool status = send_rpc_cmd(ctx->sock, FREE_BUFFER, input, output);
313+
bool status = send_rpc_cmd(ctx->back->sock, FREE_BUFFER, input, output);
289314
GGML_ASSERT(status);
290315
GGML_ASSERT(output.empty());
316+
ctx->back->ref_count--;
317+
if (ctx->back->ref_count == 0) {
318+
free_rpc_backend(ctx->back);
319+
}
291320
delete ctx;
292321
}
293322

@@ -301,7 +330,7 @@ GGML_CALL static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t b
301330
uint64_t remote_ptr = ctx->remote_ptr;
302331
memcpy(input.data(), &remote_ptr, sizeof(remote_ptr));
303332
std::vector<uint8_t> output;
304-
bool status = send_rpc_cmd(ctx->sock, BUFFER_GET_BASE, input, output);
333+
bool status = send_rpc_cmd(ctx->back->sock, BUFFER_GET_BASE, input, output);
305334
GGML_ASSERT(status);
306335
GGML_ASSERT(output.size() == sizeof(uint64_t));
307336
// output serialization format: | base_ptr (8 bytes) |
@@ -360,7 +389,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t b
360389
memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
361390
memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), data, size);
362391
std::vector<uint8_t> output;
363-
bool status = send_rpc_cmd(ctx->sock, SET_TENSOR, input, output);
392+
bool status = send_rpc_cmd(ctx->back->sock, SET_TENSOR, input, output);
364393
GGML_ASSERT(status);
365394
}
366395

@@ -374,7 +403,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t b
374403
memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
375404
memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), &size, sizeof(size));
376405
std::vector<uint8_t> output;
377-
bool status = send_rpc_cmd(ctx->sock, GET_TENSOR, input, output);
406+
bool status = send_rpc_cmd(ctx->back->sock, GET_TENSOR, input, output);
378407
GGML_ASSERT(status);
379408
GGML_ASSERT(output.size() == size);
380409
// output serialization format: | data (size bytes) |
@@ -387,7 +416,7 @@ GGML_CALL static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t b
387416
ggml_backend_rpc_buffer_context * src_ctx = (ggml_backend_rpc_buffer_context *)src_buffer->context;
388417
ggml_backend_buffer_t dst_buffer = dst->buffer;
389418
ggml_backend_rpc_buffer_context * dst_ctx = (ggml_backend_rpc_buffer_context *)dst_buffer->context;
390-
if (src_ctx->sock != dst_ctx->sock) {
419+
if (src_ctx->back != dst_ctx->back) {
391420
return false;
392421
}
393422
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
@@ -399,7 +428,7 @@ GGML_CALL static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t b
399428
memcpy(input.data(), &rpc_src, sizeof(rpc_src));
400429
memcpy(input.data() + sizeof(rpc_src), &rpc_dst, sizeof(rpc_dst));
401430
std::vector<uint8_t> output;
402-
bool status = send_rpc_cmd(ctx->sock, COPY_TENSOR, input, output);
431+
bool status = send_rpc_cmd(ctx->back->sock, COPY_TENSOR, input, output);
403432
GGML_ASSERT(status);
404433
// output serialization format: | result (1 byte) |
405434
GGML_ASSERT(output.size() == 1);
@@ -414,7 +443,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer
414443
memcpy(input.data(), &ctx->remote_ptr, sizeof(ctx->remote_ptr));
415444
memcpy(input.data() + sizeof(ctx->remote_ptr), &value, sizeof(value));
416445
std::vector<uint8_t> output;
417-
bool status = send_rpc_cmd(ctx->sock, BUFFER_CLEAR, input, output);
446+
bool status = send_rpc_cmd(ctx->back->sock, BUFFER_CLEAR, input, output);
418447
GGML_ASSERT(status);
419448
}
420449

@@ -442,7 +471,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
442471
std::vector<uint8_t> input(input_size, 0);
443472
memcpy(input.data(), &size, sizeof(size));
444473
std::vector<uint8_t> output;
445-
bool status = send_rpc_cmd(buft_ctx->sock, ALLOC_BUFFER, input, output);
474+
bool status = send_rpc_cmd(buft_ctx->back->sock, ALLOC_BUFFER, input, output);
446475
GGML_ASSERT(status);
447476
GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
448477
// output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
@@ -453,8 +482,9 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
453482
if (remote_ptr != 0) {
454483
ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
455484
ggml_backend_rpc_buffer_interface,
456-
new ggml_backend_rpc_buffer_context{buft_ctx->sock, {}, remote_ptr, "RPC"},
485+
new ggml_backend_rpc_buffer_context{buft_ctx->back, {}, remote_ptr, "RPC"},
457486
remote_size);
487+
buft_ctx->back->ref_count++;
458488
return buffer;
459489
} else {
460490
return nullptr;
@@ -508,7 +538,7 @@ GGML_CALL static bool ggml_backend_rpc_buffer_type_supports_backend(ggml_backend
508538
}
509539
ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
510540
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
511-
return buft_ctx->sock == rpc_ctx->sock;
541+
return buft_ctx->back == rpc_ctx->back;
512542
}
513543

514544
static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
@@ -521,7 +551,6 @@ static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
521551
/* .is_host = */ NULL,
522552
};
523553

524-
525554
GGML_CALL static const char * ggml_backend_rpc_name(ggml_backend_t backend) {
526555
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
527556

@@ -530,11 +559,10 @@ GGML_CALL static const char * ggml_backend_rpc_name(ggml_backend_t backend) {
530559

531560
GGML_CALL static void ggml_backend_rpc_free(ggml_backend_t backend) {
532561
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
533-
ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)rpc_ctx->buft->context;
534-
delete buft_ctx;
535-
delete rpc_ctx->buft;
536-
delete rpc_ctx;
537-
delete backend;
562+
rpc_ctx->back->ref_count--;
563+
if (rpc_ctx->back->ref_count == 0) {
564+
free_rpc_backend(rpc_ctx->back);
565+
}
538566
}
539567

540568
GGML_CALL static ggml_backend_buffer_type_t ggml_backend_rpc_get_default_buffer_type(ggml_backend_t backend) {
@@ -590,7 +618,7 @@ GGML_CALL static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t
590618
std::vector<uint8_t> input;
591619
serialize_graph(cgraph, input);
592620
std::vector<uint8_t> output;
593-
bool status = send_rpc_cmd(rpc_ctx->sock, GRAPH_COMPUTE, input, output);
621+
bool status = send_rpc_cmd(rpc_ctx->back->sock, GRAPH_COMPUTE, input, output);
594622
GGML_ASSERT(status);
595623
GGML_ASSERT(output.size() == 1);
596624
return (enum ggml_status)output[0];
@@ -624,17 +652,9 @@ static ggml_backend_i ggml_backend_rpc_interface = {
624652
/* .event_synchronize = */ NULL,
625653
};
626654

627-
static std::unordered_map<std::string, ggml_backend_t> instances;
628-
629-
GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) {
630-
ggml_backend_t backend = ggml_backend_rpc_init(endpoint);
631-
return backend != nullptr ? ggml_backend_rpc_get_default_buffer_type(backend) : nullptr;
632-
}
633-
634-
GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
635-
std::string endpoint_str(endpoint);
636-
if (instances.find(endpoint_str) != instances.end()) {
637-
return instances[endpoint_str];
655+
static rpc_backend_ptr create_rpc_backend(const std::string & endpoint) {
656+
if (instances.find(endpoint) != instances.end()) {
657+
return instances[endpoint];
638658
}
639659
#ifdef _WIN32
640660
{
@@ -645,7 +665,7 @@ GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
645665
}
646666
}
647667
#endif
648-
fprintf(stderr, "Connecting to %s\n", endpoint);
668+
fprintf(stderr, "Connecting to %s\n", endpoint.c_str());
649669
std::string host;
650670
int port;
651671
if (!parse_endpoint(endpoint, host, port)) {
@@ -657,11 +677,12 @@ GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
657677
}
658678
size_t alignment = get_alignment(sock);
659679
size_t max_size = get_max_size(sock);
680+
auto rpc_back = std::make_shared<rpc_backend>();
660681
ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context {
661-
/* .sock = */ sock,
662-
/* .name = */ "RPC" + std::to_string(sock->fd),
682+
/* .back = */ rpc_back,
683+
/* .name = */ "RPC" + std::to_string(sock->fd),
663684
/* .alignment = */ alignment,
664-
/* .max_size = */ max_size
685+
/* .max_size = */ max_size
665686
};
666687

667688
ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type {
@@ -670,19 +691,37 @@ GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
670691
};
671692

672693
ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
673-
/* .endpoint = */ endpoint,
674694
/* .name = */ "RPC" + std::to_string(sock->fd),
675-
/* .sock = */ sock,
695+
/* .back = */ rpc_back,
676696
/* .buft = */ buft
677697
};
678698

679-
instances[endpoint] = new ggml_backend {
699+
ggml_backend_t backend = new ggml_backend {
680700
/* .guid = */ ggml_backend_rpc_guid(),
681701
/* .interface = */ ggml_backend_rpc_interface,
682702
/* .context = */ ctx
683703
};
704+
rpc_back->sock = sock;
705+
rpc_back->endpoint = endpoint;
706+
rpc_back->backend = backend;
707+
rpc_back->ref_count = 0;
708+
instances[endpoint] = rpc_back;
709+
return rpc_back;
710+
}
684711

685-
return instances[endpoint];
712+
GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) {
713+
auto rpc_back = create_rpc_backend(endpoint);
714+
return rpc_back != nullptr ? ggml_backend_rpc_get_default_buffer_type(rpc_back->backend) : nullptr;
715+
}
716+
717+
GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
718+
std::string endpoint_str(endpoint);
719+
auto rpc_back = create_rpc_backend(endpoint_str);
720+
if (rpc_back == nullptr) {
721+
return nullptr;
722+
}
723+
rpc_back->ref_count++;
724+
return rpc_back->backend;
686725
}
687726

688727
GGML_API GGML_CALL bool ggml_backend_is_rpc(ggml_backend_t backend) {
@@ -706,14 +745,13 @@ static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * f
706745
}
707746

708747
GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) {
709-
ggml_backend_t backend = ggml_backend_rpc_init(endpoint);
710-
if (backend == nullptr) {
748+
auto rpc_back = create_rpc_backend(endpoint);
749+
if (rpc_back == nullptr) {
711750
*free = 0;
712751
*total = 0;
713752
return;
714753
}
715-
ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context;
716-
get_device_memory(ctx->sock, free, total);
754+
get_device_memory(rpc_back->sock, free, total);
717755
}
718756

719757
// RPC server-side implementation

0 commit comments

Comments
 (0)