Skip to content

Commit 5aeb71f

Browse files
slaren authored and arthw committed
rpc : add backend registry / device interfaces (ggml-org#9812)
* rpc : add backend registry / device interfaces
* llama : add llama_supports_rpc API
* ggml_backend_rpc_start_rpc_server -> ggml_backend_rpc_start_server
1 parent 0e9d3c9 commit 5aeb71f

File tree

8 files changed

+247
-88
lines changed

8 files changed

+247
-88
lines changed

common/arg.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1353,15 +1353,15 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
13531353
params.image.emplace_back(value);
13541354
}
13551355
).set_examples({LLAMA_EXAMPLE_LLAVA}));
1356-
#ifdef GGML_USE_RPC
1357-
add_opt(llama_arg(
1358-
{"--rpc"}, "SERVERS",
1359-
"comma separated list of RPC servers",
1360-
[](gpt_params & params, const std::string & value) {
1361-
params.rpc_servers = value;
1362-
}
1363-
).set_env("LLAMA_ARG_RPC"));
1364-
#endif
1356+
if (llama_supports_rpc()) {
1357+
add_opt(llama_arg(
1358+
{"--rpc"}, "SERVERS",
1359+
"comma separated list of RPC servers",
1360+
[](gpt_params & params, const std::string & value) {
1361+
params.rpc_servers = value;
1362+
}
1363+
).set_env("LLAMA_ARG_RPC"));
1364+
}
13651365
add_opt(llama_arg(
13661366
{"--mlock"},
13671367
"force system to keep model in RAM rather than swapping or compressing",

examples/llama-bench/llama-bench.cpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -304,9 +304,9 @@ static void print_usage(int /* argc */, char ** argv) {
304304
printf(" --cpu-strict <0|1> (default: %s)\n", join(cmd_params_defaults.cpu_strict, ",").c_str());
305305
printf(" --poll <0...100> (default: %s)\n", join(cmd_params_defaults.poll, ",").c_str());
306306
printf(" -ngl, --n-gpu-layers <n> (default: %s)\n", join(cmd_params_defaults.n_gpu_layers, ",").c_str());
307-
#ifdef GGML_USE_RPC
308-
printf(" -rpc, --rpc <rpc_servers> (default: %s)\n", join(cmd_params_defaults.rpc_servers, ",").c_str());
309-
#endif
307+
if (llama_supports_rpc()) {
308+
printf(" -rpc, --rpc <rpc_servers> (default: %s)\n", join(cmd_params_defaults.rpc_servers, ",").c_str());
309+
}
310310
printf(" -sm, --split-mode <none|layer|row> (default: %s)\n", join(transform_to_str(cmd_params_defaults.split_mode, split_mode_str), ",").c_str());
311311
printf(" -mg, --main-gpu <i> (default: %s)\n", join(cmd_params_defaults.main_gpu, ",").c_str());
312312
printf(" -nkvo, --no-kv-offload <0|1> (default: %s)\n", join(cmd_params_defaults.no_kv_offload, ",").c_str());
@@ -497,14 +497,12 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
497497
}
498498
auto p = string_split<int>(argv[i], split_delim);
499499
params.n_gpu_layers.insert(params.n_gpu_layers.end(), p.begin(), p.end());
500-
#ifdef GGML_USE_RPC
501-
} else if (arg == "-rpc" || arg == "--rpc") {
500+
} else if (llama_supports_rpc() && (arg == "-rpc" || arg == "--rpc")) {
502501
if (++i >= argc) {
503502
invalid_param = true;
504503
break;
505504
}
506505
params.rpc_servers.push_back(argv[i]);
507-
#endif
508506
} else if (arg == "-sm" || arg == "--split-mode") {
509507
if (++i >= argc) {
510508
invalid_param = true;

examples/rpc/rpc-server.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ int main(int argc, char * argv[]) {
151151
get_backend_memory(&free_mem, &total_mem);
152152
}
153153
printf("Starting RPC server on %s, backend memory: %zu MB\n", endpoint.c_str(), free_mem / (1024 * 1024));
154-
start_rpc_server(backend, endpoint.c_str(), free_mem, total_mem);
154+
ggml_backend_rpc_start_server(backend, endpoint.c_str(), free_mem, total_mem);
155155
ggml_backend_free(backend);
156156
return 0;
157157
}

ggml/include/ggml-rpc.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,11 @@ GGML_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * en
1717

1818
GGML_API void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total);
1919

20-
GGML_API void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem);
20+
GGML_API void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem);
21+
22+
GGML_API ggml_backend_reg_t ggml_backend_rpc_reg(void);
23+
24+
GGML_API ggml_backend_dev_t ggml_backend_rpc_add_device(const char * endpoint);
2125

2226
#ifdef __cplusplus
2327
}

ggml/src/ggml-backend.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -542,6 +542,10 @@ void * ggml_backend_reg_get_proc_address(ggml_backend_reg_t reg, const char * na
542542
#include "ggml-blas.h"
543543
#endif
544544

545+
#ifdef GGML_USE_RPC
546+
#include "ggml-rpc.h"
547+
#endif
548+
545549
struct ggml_backend_registry {
546550
std::vector<ggml_backend_reg_t> backends;
547551
std::vector<ggml_backend_dev_t> devices;
@@ -556,6 +560,9 @@ struct ggml_backend_registry {
556560
#ifdef GGML_USE_BLAS
557561
register_backend(ggml_backend_blas_reg());
558562
#endif
563+
#ifdef GGML_USE_RPC
564+
register_backend(ggml_backend_rpc_reg());
565+
#endif
559566

560567
// TODO: sycl, vulkan, kompute, cann
561568

ggml/src/ggml-rpc.cpp

Lines changed: 182 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
# include <netdb.h>
2626
# include <unistd.h>
2727
#endif
28-
#include <string.h>
28+
#include <cstring>
2929

3030
#define UNUSED GGML_UNUSED
3131

@@ -630,22 +630,6 @@ static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, g
630630
return (enum ggml_status)output[0];
631631
}
632632

633-
static bool ggml_backend_rpc_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
634-
UNUSED(backend);
635-
UNUSED(op);
636-
//TODO: call the remote backend and cache the results
637-
return true;
638-
}
639-
640-
static bool ggml_backend_rpc_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
641-
if (!buft || buft->iface.get_name != ggml_backend_rpc_buffer_type_name) {
642-
return false;
643-
}
644-
ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
645-
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
646-
return buft_ctx->endpoint == rpc_ctx->endpoint;
647-
}
648-
649633
static ggml_backend_i ggml_backend_rpc_interface = {
650634
/* .get_name = */ ggml_backend_rpc_name,
651635
/* .free = */ ggml_backend_rpc_free,
@@ -659,8 +643,8 @@ static ggml_backend_i ggml_backend_rpc_interface = {
659643
/* .graph_plan_update = */ NULL,
660644
/* .graph_plan_compute = */ NULL,
661645
/* .graph_compute = */ ggml_backend_rpc_graph_compute,
662-
/* .supports_op = */ ggml_backend_rpc_supports_op,
663-
/* .supports_buft = */ ggml_backend_rpc_supports_buft,
646+
/* .supports_op = */ NULL,
647+
/* .supports_buft = */ NULL,
664648
/* .offload_op = */ NULL,
665649
/* .event_record = */ NULL,
666650
/* .event_wait = */ NULL,
@@ -691,7 +675,7 @@ GGML_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * en
691675

692676
ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type {
693677
/* .iface = */ ggml_backend_rpc_buffer_type_interface,
694-
/* .device = */ nullptr,
678+
/* .device = */ ggml_backend_rpc_add_device(endpoint),
695679
/* .context = */ buft_ctx
696680
};
697681
buft_map[endpoint] = buft;
@@ -707,7 +691,7 @@ ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
707691
ggml_backend_t backend = new ggml_backend {
708692
/* .guid = */ ggml_backend_rpc_guid(),
709693
/* .interface = */ ggml_backend_rpc_interface,
710-
/* .device = */ nullptr,
694+
/* .device = */ ggml_backend_rpc_add_device(endpoint),
711695
/* .context = */ ctx
712696
};
713697
return backend;
@@ -1189,7 +1173,7 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
11891173
}
11901174
}
11911175

1192-
void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem) {
1176+
void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem) {
11931177
std::string host;
11941178
int port;
11951179
if (!parse_endpoint(endpoint, host, port)) {
@@ -1226,3 +1210,179 @@ void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free
12261210
WSACleanup();
12271211
#endif
12281212
}
1213+
1214+
// device interface
1215+
1216+
struct ggml_backend_rpc_device_context {
1217+
std::string endpoint;
1218+
std::string name;
1219+
};
1220+
1221+
static const char * ggml_backend_rpc_device_get_name(ggml_backend_dev_t dev) {
1222+
ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
1223+
1224+
return ctx->name.c_str();
1225+
}
1226+
1227+
static const char * ggml_backend_rpc_device_get_description(ggml_backend_dev_t dev) {
1228+
ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
1229+
1230+
return ctx->name.c_str();
1231+
}
1232+
1233+
static void ggml_backend_rpc_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
1234+
ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
1235+
1236+
ggml_backend_rpc_get_device_memory(ctx->endpoint.c_str(), free, total);
1237+
1238+
UNUSED(dev);
1239+
}
1240+
1241+
static enum ggml_backend_dev_type ggml_backend_rpc_device_get_type(ggml_backend_dev_t dev) {
1242+
// TODO: obtain value from the server
1243+
return GGML_BACKEND_DEVICE_TYPE_GPU_FULL;
1244+
1245+
UNUSED(dev);
1246+
}
1247+
1248+
static void ggml_backend_rpc_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
1249+
props->name = ggml_backend_rpc_device_get_name(dev);
1250+
props->description = ggml_backend_rpc_device_get_description(dev);
1251+
props->type = ggml_backend_rpc_device_get_type(dev);
1252+
ggml_backend_rpc_device_get_memory(dev, &props->memory_free, &props->memory_total);
1253+
props->caps = {
1254+
/* .async = */ false,
1255+
/* .host_buffer = */ false,
1256+
/* .buffer_from_host_ptr = */ false,
1257+
/* .events = */ false,
1258+
};
1259+
}
1260+
1261+
static ggml_backend_t ggml_backend_rpc_device_init(ggml_backend_dev_t dev, const char * params) {
1262+
ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
1263+
1264+
return ggml_backend_rpc_init(ctx->endpoint.c_str());
1265+
1266+
UNUSED(params);
1267+
}
1268+
1269+
static ggml_backend_buffer_type_t ggml_backend_rpc_device_get_buffer_type(ggml_backend_dev_t dev) {
1270+
ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
1271+
1272+
return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str());
1273+
1274+
UNUSED(dev);
1275+
}
1276+
1277+
static ggml_backend_buffer_t ggml_backend_rpc_device_buffer_from_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
1278+
return ggml_backend_cpu_buffer_from_ptr(ptr, size);
1279+
1280+
UNUSED(dev);
1281+
UNUSED(max_tensor_size);
1282+
}
1283+
1284+
static bool ggml_backend_rpc_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
1285+
UNUSED(dev);
1286+
UNUSED(op);
1287+
//TODO: call the remote backend and cache the results
1288+
return true;
1289+
}
1290+
1291+
static bool ggml_backend_rpc_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
1292+
if (!buft || buft->iface.get_name != ggml_backend_rpc_buffer_type_name) {
1293+
return false;
1294+
}
1295+
ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
1296+
ggml_backend_rpc_device_context * dev_ctx = (ggml_backend_rpc_device_context *)dev->context;
1297+
return buft_ctx->endpoint == dev_ctx->endpoint;
1298+
}
1299+
1300+
static const struct ggml_backend_device_i ggml_backend_rpc_device_i = {
1301+
/* .get_name = */ ggml_backend_rpc_device_get_name,
1302+
/* .get_description = */ ggml_backend_rpc_device_get_description,
1303+
/* .get_memory = */ ggml_backend_rpc_device_get_memory,
1304+
/* .get_type = */ ggml_backend_rpc_device_get_type,
1305+
/* .get_props = */ ggml_backend_rpc_device_get_props,
1306+
/* .init_backend = */ ggml_backend_rpc_device_init,
1307+
/* .get_buffer_type = */ ggml_backend_rpc_device_get_buffer_type,
1308+
/* .get_host_buffer_type = */ NULL,
1309+
/* .buffer_from_host_ptr = */ ggml_backend_rpc_device_buffer_from_ptr,
1310+
/* .supports_op = */ ggml_backend_rpc_device_supports_op,
1311+
/* .supports_buft = */ ggml_backend_rpc_device_supports_buft,
1312+
/* .offload_op = */ NULL,
1313+
/* .event_new = */ NULL,
1314+
/* .event_free = */ NULL,
1315+
/* .event_synchronize = */ NULL,
1316+
};
1317+
1318+
// backend reg interface
1319+
1320+
static const char * ggml_backend_rpc_reg_get_name(ggml_backend_reg_t reg) {
1321+
return "RPC";
1322+
1323+
UNUSED(reg);
1324+
}
1325+
1326+
static size_t ggml_backend_rpc_reg_get_device_count(ggml_backend_reg_t reg) {
1327+
return 0;
1328+
1329+
UNUSED(reg);
1330+
}
1331+
1332+
static ggml_backend_dev_t ggml_backend_rpc_reg_get_device(ggml_backend_reg_t reg, size_t index) {
1333+
GGML_ABORT("The RPC backend does not have enumerated devices - use ggml_backend_add_device instead");
1334+
1335+
UNUSED(reg);
1336+
UNUSED(index);
1337+
}
1338+
1339+
static void * ggml_backend_rpc_get_proc_address(ggml_backend_reg_t reg, const char * name) {
1340+
if (std::strcmp(name, "ggml_backend_rpc_add_device") == 0) {
1341+
return (void *)ggml_backend_rpc_add_device;
1342+
}
1343+
return NULL;
1344+
1345+
UNUSED(reg);
1346+
}
1347+
1348+
static const struct ggml_backend_reg_i ggml_backend_rpc_reg_i = {
1349+
/* .get_name = */ ggml_backend_rpc_reg_get_name,
1350+
/* .get_device_count = */ ggml_backend_rpc_reg_get_device_count,
1351+
/* .get_device = */ ggml_backend_rpc_reg_get_device,
1352+
/* .get_proc_address = */ ggml_backend_rpc_get_proc_address,
1353+
};
1354+
1355+
ggml_backend_reg_t ggml_backend_rpc_reg(void) {
1356+
static struct ggml_backend_reg ggml_backend_rpc_reg = {
1357+
/* .iface = */ ggml_backend_rpc_reg_i,
1358+
/* .context = */ NULL,
1359+
};
1360+
1361+
return &ggml_backend_rpc_reg;
1362+
}
1363+
1364+
ggml_backend_dev_t ggml_backend_rpc_add_device(const char * endpoint) {
1365+
static std::unordered_map<std::string, ggml_backend_dev_t> dev_map;
1366+
1367+
static std::mutex mutex;
1368+
std::lock_guard<std::mutex> lock(mutex);
1369+
1370+
if (dev_map.find(endpoint) != dev_map.end()) {
1371+
return dev_map[endpoint];
1372+
}
1373+
1374+
ggml_backend_rpc_device_context * ctx = new ggml_backend_rpc_device_context {
1375+
/* .endpoint = */ endpoint,
1376+
/* .name = */ "RPC[" + std::string(endpoint) + "]",
1377+
};
1378+
1379+
ggml_backend_dev_t dev = new ggml_backend_device {
1380+
/* .iface = */ ggml_backend_rpc_device_i,
1381+
/* .reg = */ ggml_backend_rpc_reg(),
1382+
/* .context = */ ctx,
1383+
};
1384+
1385+
dev_map[endpoint] = dev;
1386+
1387+
return dev;
1388+
}

include/llama.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,7 @@ extern "C" {
433433
LLAMA_API bool llama_supports_mmap (void);
434434
LLAMA_API bool llama_supports_mlock (void);
435435
LLAMA_API bool llama_supports_gpu_offload(void);
436+
LLAMA_API bool llama_supports_rpc (void);
436437

437438
LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx);
438439
LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx);

0 commit comments

Comments (0)