@@ -96,27 +96,37 @@ static ggml_guid_t ggml_backend_rpc_guid() {
96
96
return &guid;
97
97
}
98
98
99
- struct ggml_backend_rpc_buffer_type_context {
99
+ struct rpc_backend {
100
+ int ref_count;
101
+ std::string endpoint;
100
102
std::shared_ptr<socket_t > sock;
103
+ ggml_backend_t backend;
104
+ };
105
+
106
+ using rpc_backend_ptr = std::shared_ptr<rpc_backend>;
107
+
108
+ struct ggml_backend_rpc_buffer_type_context {
109
+ std::shared_ptr<rpc_backend> back;
101
110
std::string name;
102
111
size_t alignment;
103
112
size_t max_size;
104
113
};
105
114
106
115
struct ggml_backend_rpc_context {
107
- std::string endpoint;
108
116
std::string name;
109
- std::shared_ptr<socket_t > sock ;
117
+ std::shared_ptr<rpc_backend> back ;
110
118
ggml_backend_buffer_type_t buft;
111
119
};
112
120
113
121
struct ggml_backend_rpc_buffer_context {
114
- std::shared_ptr<socket_t > sock ;
122
+ std::shared_ptr<rpc_backend> back ;
115
123
std::unordered_map<ggml_backend_buffer_t , void *> base_cache;
116
124
uint64_t remote_ptr;
117
125
std::string name;
118
126
};
119
127
128
+ static std::unordered_map<std::string, rpc_backend_ptr> instances;
129
+
120
130
// RPC helper functions
121
131
122
132
static std::shared_ptr<socket_t > make_socket (sockfd_t fd) {
@@ -231,14 +241,13 @@ static bool recv_data(sockfd_t sockfd, void * data, size_t size) {
231
241
return true ;
232
242
}
233
243
234
- static bool parse_endpoint (const char * endpoint, std::string & host, int & port) {
235
- std::string str (endpoint);
236
- size_t pos = str.find (' :' );
244
+ static bool parse_endpoint (const std::string & endpoint, std::string & host, int & port) {
245
+ size_t pos = endpoint.find (' :' );
237
246
if (pos == std::string::npos) {
238
247
return false ;
239
248
}
240
- host = str .substr (0 , pos);
241
- port = std::stoi (str .substr (pos + 1 ));
249
+ host = endpoint .substr (0 , pos);
250
+ port = std::stoi (endpoint .substr (pos + 1 ));
242
251
return true ;
243
252
}
244
253
@@ -273,6 +282,22 @@ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cm
273
282
274
283
// RPC client-side implementation
275
284
285
+ static void free_rpc_backend (rpc_backend_ptr rpc_back) {
286
+ ggml_backend_t backend = rpc_back->backend ;
287
+ ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context ;
288
+ std::string endpoint = rpc_back->endpoint ;
289
+ ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)rpc_ctx->buft ->context ;
290
+ GGML_PRINT_DEBUG (" [%s] closing connection to %s\n " , __func__, endpoint.c_str ());
291
+ delete buft_ctx;
292
+ delete rpc_ctx->buft ;
293
+ delete rpc_ctx;
294
+ delete backend;
295
+ instances.erase (endpoint);
296
+ #ifdef _WIN32
297
+ WSACleanup ();
298
+ #endif
299
+ }
300
+
276
301
GGML_CALL static const char * ggml_backend_rpc_buffer_get_name (ggml_backend_buffer_t buffer) {
277
302
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context ;
278
303
return ctx->name .c_str ();
@@ -285,9 +310,13 @@ GGML_CALL static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t
285
310
uint64_t remote_ptr = ctx->remote_ptr ;
286
311
memcpy (input.data (), &remote_ptr, sizeof (remote_ptr));
287
312
std::vector<uint8_t > output;
288
- bool status = send_rpc_cmd (ctx->sock , FREE_BUFFER, input, output);
313
+ bool status = send_rpc_cmd (ctx->back -> sock , FREE_BUFFER, input, output);
289
314
GGML_ASSERT (status);
290
315
GGML_ASSERT (output.empty ());
316
+ ctx->back ->ref_count --;
317
+ if (ctx->back ->ref_count == 0 ) {
318
+ free_rpc_backend (ctx->back );
319
+ }
291
320
delete ctx;
292
321
}
293
322
@@ -301,7 +330,7 @@ GGML_CALL static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t b
301
330
uint64_t remote_ptr = ctx->remote_ptr ;
302
331
memcpy (input.data (), &remote_ptr, sizeof (remote_ptr));
303
332
std::vector<uint8_t > output;
304
- bool status = send_rpc_cmd (ctx->sock , BUFFER_GET_BASE, input, output);
333
+ bool status = send_rpc_cmd (ctx->back -> sock , BUFFER_GET_BASE, input, output);
305
334
GGML_ASSERT (status);
306
335
GGML_ASSERT (output.size () == sizeof (uint64_t ));
307
336
// output serialization format: | base_ptr (8 bytes) |
@@ -360,7 +389,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t b
360
389
memcpy (input.data () + sizeof (rpc_tensor), &offset, sizeof (offset));
361
390
memcpy (input.data () + sizeof (rpc_tensor) + sizeof (offset), data, size);
362
391
std::vector<uint8_t > output;
363
- bool status = send_rpc_cmd (ctx->sock , SET_TENSOR, input, output);
392
+ bool status = send_rpc_cmd (ctx->back -> sock , SET_TENSOR, input, output);
364
393
GGML_ASSERT (status);
365
394
}
366
395
@@ -374,7 +403,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t b
374
403
memcpy (input.data () + sizeof (rpc_tensor), &offset, sizeof (offset));
375
404
memcpy (input.data () + sizeof (rpc_tensor) + sizeof (offset), &size, sizeof (size));
376
405
std::vector<uint8_t > output;
377
- bool status = send_rpc_cmd (ctx->sock , GET_TENSOR, input, output);
406
+ bool status = send_rpc_cmd (ctx->back -> sock , GET_TENSOR, input, output);
378
407
GGML_ASSERT (status);
379
408
GGML_ASSERT (output.size () == size);
380
409
// output serialization format: | data (size bytes) |
@@ -387,7 +416,7 @@ GGML_CALL static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t b
387
416
ggml_backend_rpc_buffer_context * src_ctx = (ggml_backend_rpc_buffer_context *)src_buffer->context ;
388
417
ggml_backend_buffer_t dst_buffer = dst->buffer ;
389
418
ggml_backend_rpc_buffer_context * dst_ctx = (ggml_backend_rpc_buffer_context *)dst_buffer->context ;
390
- if (src_ctx->sock != dst_ctx->sock ) {
419
+ if (src_ctx->back != dst_ctx->back ) {
391
420
return false ;
392
421
}
393
422
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context ;
@@ -399,7 +428,7 @@ GGML_CALL static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t b
399
428
memcpy (input.data (), &rpc_src, sizeof (rpc_src));
400
429
memcpy (input.data () + sizeof (rpc_src), &rpc_dst, sizeof (rpc_dst));
401
430
std::vector<uint8_t > output;
402
- bool status = send_rpc_cmd (ctx->sock , COPY_TENSOR, input, output);
431
+ bool status = send_rpc_cmd (ctx->back -> sock , COPY_TENSOR, input, output);
403
432
GGML_ASSERT (status);
404
433
// output serialization format: | result (1 byte) |
405
434
GGML_ASSERT (output.size () == 1 );
@@ -414,7 +443,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer
414
443
memcpy (input.data (), &ctx->remote_ptr , sizeof (ctx->remote_ptr ));
415
444
memcpy (input.data () + sizeof (ctx->remote_ptr ), &value, sizeof (value));
416
445
std::vector<uint8_t > output;
417
- bool status = send_rpc_cmd (ctx->sock , BUFFER_CLEAR, input, output);
446
+ bool status = send_rpc_cmd (ctx->back -> sock , BUFFER_CLEAR, input, output);
418
447
GGML_ASSERT (status);
419
448
}
420
449
@@ -442,7 +471,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
442
471
std::vector<uint8_t > input (input_size, 0 );
443
472
memcpy (input.data (), &size, sizeof (size));
444
473
std::vector<uint8_t > output;
445
- bool status = send_rpc_cmd (buft_ctx->sock , ALLOC_BUFFER, input, output);
474
+ bool status = send_rpc_cmd (buft_ctx->back -> sock , ALLOC_BUFFER, input, output);
446
475
GGML_ASSERT (status);
447
476
GGML_ASSERT (output.size () == 2 *sizeof (uint64_t ));
448
477
// output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
@@ -453,8 +482,9 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
453
482
if (remote_ptr != 0 ) {
454
483
ggml_backend_buffer_t buffer = ggml_backend_buffer_init (buft,
455
484
ggml_backend_rpc_buffer_interface,
456
- new ggml_backend_rpc_buffer_context{buft_ctx->sock , {}, remote_ptr, " RPC" },
485
+ new ggml_backend_rpc_buffer_context{buft_ctx->back , {}, remote_ptr, " RPC" },
457
486
remote_size);
487
+ buft_ctx->back ->ref_count ++;
458
488
return buffer;
459
489
} else {
460
490
return nullptr ;
@@ -508,7 +538,7 @@ GGML_CALL static bool ggml_backend_rpc_buffer_type_supports_backend(ggml_backend
508
538
}
509
539
ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context ;
510
540
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context ;
511
- return buft_ctx->sock == rpc_ctx->sock ;
541
+ return buft_ctx->back == rpc_ctx->back ;
512
542
}
513
543
514
544
static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
@@ -521,7 +551,6 @@ static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
521
551
/* .is_host = */ NULL ,
522
552
};
523
553
524
-
525
554
GGML_CALL static const char * ggml_backend_rpc_name (ggml_backend_t backend) {
526
555
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context ;
527
556
@@ -530,11 +559,10 @@ GGML_CALL static const char * ggml_backend_rpc_name(ggml_backend_t backend) {
530
559
531
560
GGML_CALL static void ggml_backend_rpc_free (ggml_backend_t backend) {
532
561
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context ;
533
- ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)rpc_ctx->buft ->context ;
534
- delete buft_ctx;
535
- delete rpc_ctx->buft ;
536
- delete rpc_ctx;
537
- delete backend;
562
+ rpc_ctx->back ->ref_count --;
563
+ if (rpc_ctx->back ->ref_count == 0 ) {
564
+ free_rpc_backend (rpc_ctx->back );
565
+ }
538
566
}
539
567
540
568
GGML_CALL static ggml_backend_buffer_type_t ggml_backend_rpc_get_default_buffer_type (ggml_backend_t backend) {
@@ -590,7 +618,7 @@ GGML_CALL static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t
590
618
std::vector<uint8_t > input;
591
619
serialize_graph (cgraph, input);
592
620
std::vector<uint8_t > output;
593
- bool status = send_rpc_cmd (rpc_ctx->sock , GRAPH_COMPUTE, input, output);
621
+ bool status = send_rpc_cmd (rpc_ctx->back -> sock , GRAPH_COMPUTE, input, output);
594
622
GGML_ASSERT (status);
595
623
GGML_ASSERT (output.size () == 1 );
596
624
return (enum ggml_status)output[0 ];
@@ -624,17 +652,9 @@ static ggml_backend_i ggml_backend_rpc_interface = {
624
652
/* .event_synchronize = */ NULL ,
625
653
};
626
654
627
- static std::unordered_map<std::string, ggml_backend_t > instances;
628
-
629
- GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type (const char * endpoint) {
630
- ggml_backend_t backend = ggml_backend_rpc_init (endpoint);
631
- return backend != nullptr ? ggml_backend_rpc_get_default_buffer_type (backend) : nullptr ;
632
- }
633
-
634
- GGML_CALL ggml_backend_t ggml_backend_rpc_init (const char * endpoint) {
635
- std::string endpoint_str (endpoint);
636
- if (instances.find (endpoint_str) != instances.end ()) {
637
- return instances[endpoint_str];
655
+ static rpc_backend_ptr create_rpc_backend (const std::string & endpoint) {
656
+ if (instances.find (endpoint) != instances.end ()) {
657
+ return instances[endpoint];
638
658
}
639
659
#ifdef _WIN32
640
660
{
@@ -645,7 +665,7 @@ GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
645
665
}
646
666
}
647
667
#endif
648
- fprintf (stderr, " Connecting to %s\n " , endpoint);
668
+ fprintf (stderr, " Connecting to %s\n " , endpoint. c_str () );
649
669
std::string host;
650
670
int port;
651
671
if (!parse_endpoint (endpoint, host, port)) {
@@ -657,11 +677,12 @@ GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
657
677
}
658
678
size_t alignment = get_alignment (sock);
659
679
size_t max_size = get_max_size (sock);
680
+ auto rpc_back = std::make_shared<rpc_backend>();
660
681
ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context {
661
- /* .sock = */ sock ,
662
- /* .name = */ " RPC" + std::to_string (sock->fd ),
682
+ /* .back = */ rpc_back ,
683
+ /* .name = */ " RPC" + std::to_string (sock->fd ),
663
684
/* .alignment = */ alignment,
664
- /* .max_size = */ max_size
685
+ /* .max_size = */ max_size
665
686
};
666
687
667
688
ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type {
@@ -670,19 +691,37 @@ GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
670
691
};
671
692
672
693
ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
673
- /* .endpoint = */ endpoint,
674
694
/* .name = */ " RPC" + std::to_string (sock->fd ),
675
- /* .sock = */ sock ,
695
+ /* .back = */ rpc_back ,
676
696
/* .buft = */ buft
677
697
};
678
698
679
- instances[endpoint] = new ggml_backend {
699
+ ggml_backend_t backend = new ggml_backend {
680
700
/* .guid = */ ggml_backend_rpc_guid (),
681
701
/* .interface = */ ggml_backend_rpc_interface,
682
702
/* .context = */ ctx
683
703
};
704
+ rpc_back->sock = sock;
705
+ rpc_back->endpoint = endpoint;
706
+ rpc_back->backend = backend;
707
+ rpc_back->ref_count = 0 ;
708
+ instances[endpoint] = rpc_back;
709
+ return rpc_back;
710
+ }
684
711
685
- return instances[endpoint];
712
+ GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type (const char * endpoint) {
713
+ auto rpc_back = create_rpc_backend (endpoint);
714
+ return rpc_back != nullptr ? ggml_backend_rpc_get_default_buffer_type (rpc_back->backend ) : nullptr ;
715
+ }
716
+
717
+ GGML_CALL ggml_backend_t ggml_backend_rpc_init (const char * endpoint) {
718
+ std::string endpoint_str (endpoint);
719
+ auto rpc_back = create_rpc_backend (endpoint_str);
720
+ if (rpc_back == nullptr ) {
721
+ return nullptr ;
722
+ }
723
+ rpc_back->ref_count ++;
724
+ return rpc_back->backend ;
686
725
}
687
726
688
727
GGML_API GGML_CALL bool ggml_backend_is_rpc (ggml_backend_t backend) {
@@ -706,14 +745,13 @@ static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * f
706
745
}
707
746
708
747
GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory (const char * endpoint, size_t * free, size_t * total) {
709
- ggml_backend_t backend = ggml_backend_rpc_init (endpoint);
710
- if (backend == nullptr ) {
748
+ auto rpc_back = create_rpc_backend (endpoint);
749
+ if (rpc_back == nullptr ) {
711
750
*free = 0 ;
712
751
*total = 0 ;
713
752
return ;
714
753
}
715
- ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context ;
716
- get_device_memory (ctx->sock , free, total);
754
+ get_device_memory (rpc_back->sock , free, total);
717
755
}
718
756
719
757
// RPC server-side implementation
0 commit comments