25
25
# include < netdb.h>
26
26
# include < unistd.h>
27
27
#endif
28
- #include < string.h >
28
+ #include < cstring >
29
29
30
30
#define UNUSED GGML_UNUSED
31
31
@@ -630,22 +630,6 @@ static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, g
630
630
return (enum ggml_status)output[0 ];
631
631
}
632
632
633
- static bool ggml_backend_rpc_supports_op (ggml_backend_t backend, const ggml_tensor * op) {
634
- UNUSED (backend);
635
- UNUSED (op);
636
- // TODO: call the remote backend and cache the results
637
- return true ;
638
- }
639
-
640
- static bool ggml_backend_rpc_supports_buft (ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
641
- if (!buft || buft->iface .get_name != ggml_backend_rpc_buffer_type_name) {
642
- return false ;
643
- }
644
- ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context ;
645
- ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context ;
646
- return buft_ctx->endpoint == rpc_ctx->endpoint ;
647
- }
648
-
649
633
static ggml_backend_i ggml_backend_rpc_interface = {
650
634
/* .get_name = */ ggml_backend_rpc_name,
651
635
/* .free = */ ggml_backend_rpc_free,
@@ -659,8 +643,8 @@ static ggml_backend_i ggml_backend_rpc_interface = {
659
643
/* .graph_plan_update = */ NULL ,
660
644
/* .graph_plan_compute = */ NULL ,
661
645
/* .graph_compute = */ ggml_backend_rpc_graph_compute,
662
- /* .supports_op = */ ggml_backend_rpc_supports_op ,
663
- /* .supports_buft = */ ggml_backend_rpc_supports_buft ,
646
+ /* .supports_op = */ NULL ,
647
+ /* .supports_buft = */ NULL ,
664
648
/* .offload_op = */ NULL ,
665
649
/* .event_record = */ NULL ,
666
650
/* .event_wait = */ NULL ,
@@ -691,7 +675,7 @@ GGML_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * en
691
675
692
676
ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type {
693
677
/* .iface = */ ggml_backend_rpc_buffer_type_interface,
694
- /* .device = */ nullptr ,
678
+ /* .device = */ ggml_backend_rpc_add_device (endpoint) ,
695
679
/* .context = */ buft_ctx
696
680
};
697
681
buft_map[endpoint] = buft;
@@ -707,7 +691,7 @@ ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
707
691
ggml_backend_t backend = new ggml_backend {
708
692
/* .guid = */ ggml_backend_rpc_guid (),
709
693
/* .interface = */ ggml_backend_rpc_interface,
710
- /* .device = */ nullptr ,
694
+ /* .device = */ ggml_backend_rpc_add_device (endpoint) ,
711
695
/* .context = */ ctx
712
696
};
713
697
return backend;
@@ -1189,7 +1173,7 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
1189
1173
}
1190
1174
}
1191
1175
1192
- void start_rpc_server (ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem) {
1176
+ void ggml_backend_rpc_start_server (ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem) {
1193
1177
std::string host;
1194
1178
int port;
1195
1179
if (!parse_endpoint (endpoint, host, port)) {
@@ -1226,3 +1210,179 @@ void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free
1226
1210
WSACleanup ();
1227
1211
#endif
1228
1212
}
1213
+
1214
+ // device interface
1215
+
1216
+ struct ggml_backend_rpc_device_context {
1217
+ std::string endpoint;
1218
+ std::string name;
1219
+ };
1220
+
1221
+ static const char * ggml_backend_rpc_device_get_name (ggml_backend_dev_t dev) {
1222
+ ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context ;
1223
+
1224
+ return ctx->name .c_str ();
1225
+ }
1226
+
1227
+ static const char * ggml_backend_rpc_device_get_description (ggml_backend_dev_t dev) {
1228
+ ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context ;
1229
+
1230
+ return ctx->name .c_str ();
1231
+ }
1232
+
1233
+ static void ggml_backend_rpc_device_get_memory (ggml_backend_dev_t dev, size_t * free, size_t * total) {
1234
+ ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context ;
1235
+
1236
+ ggml_backend_rpc_get_device_memory (ctx->endpoint .c_str (), free, total);
1237
+
1238
+ UNUSED (dev);
1239
+ }
1240
+
1241
+ static enum ggml_backend_dev_type ggml_backend_rpc_device_get_type (ggml_backend_dev_t dev) {
1242
+ // TODO: obtain value from the server
1243
+ return GGML_BACKEND_DEVICE_TYPE_GPU_FULL;
1244
+
1245
+ UNUSED (dev);
1246
+ }
1247
+
1248
+ static void ggml_backend_rpc_device_get_props (ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
1249
+ props->name = ggml_backend_rpc_device_get_name (dev);
1250
+ props->description = ggml_backend_rpc_device_get_description (dev);
1251
+ props->type = ggml_backend_rpc_device_get_type (dev);
1252
+ ggml_backend_rpc_device_get_memory (dev, &props->memory_free , &props->memory_total );
1253
+ props->caps = {
1254
+ /* .async = */ false ,
1255
+ /* .host_buffer = */ false ,
1256
+ /* .buffer_from_host_ptr = */ false ,
1257
+ /* .events = */ false ,
1258
+ };
1259
+ }
1260
+
1261
+ static ggml_backend_t ggml_backend_rpc_device_init (ggml_backend_dev_t dev, const char * params) {
1262
+ ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context ;
1263
+
1264
+ return ggml_backend_rpc_init (ctx->endpoint .c_str ());
1265
+
1266
+ UNUSED (params);
1267
+ }
1268
+
1269
+ static ggml_backend_buffer_type_t ggml_backend_rpc_device_get_buffer_type (ggml_backend_dev_t dev) {
1270
+ ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context ;
1271
+
1272
+ return ggml_backend_rpc_buffer_type (ctx->endpoint .c_str ());
1273
+
1274
+ UNUSED (dev);
1275
+ }
1276
+
1277
+ static ggml_backend_buffer_t ggml_backend_rpc_device_buffer_from_ptr (ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
1278
+ return ggml_backend_cpu_buffer_from_ptr (ptr, size);
1279
+
1280
+ UNUSED (dev);
1281
+ UNUSED (max_tensor_size);
1282
+ }
1283
+
1284
+ static bool ggml_backend_rpc_device_supports_op (ggml_backend_dev_t dev, const struct ggml_tensor * op) {
1285
+ UNUSED (dev);
1286
+ UNUSED (op);
1287
+ // TODO: call the remote backend and cache the results
1288
+ return true ;
1289
+ }
1290
+
1291
+ static bool ggml_backend_rpc_device_supports_buft (ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
1292
+ if (!buft || buft->iface .get_name != ggml_backend_rpc_buffer_type_name) {
1293
+ return false ;
1294
+ }
1295
+ ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context ;
1296
+ ggml_backend_rpc_device_context * dev_ctx = (ggml_backend_rpc_device_context *)dev->context ;
1297
+ return buft_ctx->endpoint == dev_ctx->endpoint ;
1298
+ }
1299
+
1300
+ static const struct ggml_backend_device_i ggml_backend_rpc_device_i = {
1301
+ /* .get_name = */ ggml_backend_rpc_device_get_name,
1302
+ /* .get_description = */ ggml_backend_rpc_device_get_description,
1303
+ /* .get_memory = */ ggml_backend_rpc_device_get_memory,
1304
+ /* .get_type = */ ggml_backend_rpc_device_get_type,
1305
+ /* .get_props = */ ggml_backend_rpc_device_get_props,
1306
+ /* .init_backend = */ ggml_backend_rpc_device_init,
1307
+ /* .get_buffer_type = */ ggml_backend_rpc_device_get_buffer_type,
1308
+ /* .get_host_buffer_type = */ NULL ,
1309
+ /* .buffer_from_host_ptr = */ ggml_backend_rpc_device_buffer_from_ptr,
1310
+ /* .supports_op = */ ggml_backend_rpc_device_supports_op,
1311
+ /* .supports_buft = */ ggml_backend_rpc_device_supports_buft,
1312
+ /* .offload_op = */ NULL ,
1313
+ /* .event_new = */ NULL ,
1314
+ /* .event_free = */ NULL ,
1315
+ /* .event_synchronize = */ NULL ,
1316
+ };
1317
+
1318
+ // backend reg interface
1319
+
1320
+ static const char * ggml_backend_rpc_reg_get_name (ggml_backend_reg_t reg) {
1321
+ return " RPC" ;
1322
+
1323
+ UNUSED (reg);
1324
+ }
1325
+
1326
+ static size_t ggml_backend_rpc_reg_get_device_count (ggml_backend_reg_t reg) {
1327
+ return 0 ;
1328
+
1329
+ UNUSED (reg);
1330
+ }
1331
+
1332
+ static ggml_backend_dev_t ggml_backend_rpc_reg_get_device (ggml_backend_reg_t reg, size_t index) {
1333
+ GGML_ABORT (" The RPC backend does not have enumerated devices - use ggml_backend_add_device instead" );
1334
+
1335
+ UNUSED (reg);
1336
+ UNUSED (index);
1337
+ }
1338
+
1339
+ static void * ggml_backend_rpc_get_proc_address (ggml_backend_reg_t reg, const char * name) {
1340
+ if (std::strcmp (name, " ggml_backend_rpc_add_device" ) == 0 ) {
1341
+ return (void *)ggml_backend_rpc_add_device;
1342
+ }
1343
+ return NULL ;
1344
+
1345
+ UNUSED (reg);
1346
+ }
1347
+
1348
+ static const struct ggml_backend_reg_i ggml_backend_rpc_reg_i = {
1349
+ /* .get_name = */ ggml_backend_rpc_reg_get_name,
1350
+ /* .get_device_count = */ ggml_backend_rpc_reg_get_device_count,
1351
+ /* .get_device = */ ggml_backend_rpc_reg_get_device,
1352
+ /* .get_proc_address = */ ggml_backend_rpc_get_proc_address,
1353
+ };
1354
+
1355
+ ggml_backend_reg_t ggml_backend_rpc_reg (void ) {
1356
+ static struct ggml_backend_reg ggml_backend_rpc_reg = {
1357
+ /* .iface = */ ggml_backend_rpc_reg_i,
1358
+ /* .context = */ NULL ,
1359
+ };
1360
+
1361
+ return &ggml_backend_rpc_reg;
1362
+ }
1363
+
1364
+ ggml_backend_dev_t ggml_backend_rpc_add_device (const char * endpoint) {
1365
+ static std::unordered_map<std::string, ggml_backend_dev_t > dev_map;
1366
+
1367
+ static std::mutex mutex;
1368
+ std::lock_guard<std::mutex> lock (mutex);
1369
+
1370
+ if (dev_map.find (endpoint) != dev_map.end ()) {
1371
+ return dev_map[endpoint];
1372
+ }
1373
+
1374
+ ggml_backend_rpc_device_context * ctx = new ggml_backend_rpc_device_context {
1375
+ /* .endpoint = */ endpoint,
1376
+ /* .name = */ " RPC[" + std::string (endpoint) + " ]" ,
1377
+ };
1378
+
1379
+ ggml_backend_dev_t dev = new ggml_backend_device {
1380
+ /* .iface = */ ggml_backend_rpc_device_i,
1381
+ /* .reg = */ ggml_backend_rpc_reg (),
1382
+ /* .context = */ ctx,
1383
+ };
1384
+
1385
+ dev_map[endpoint] = dev;
1386
+
1387
+ return dev;
1388
+ }
0 commit comments