Commit 2ef868d

Add '--rpc' command line option
1 parent 89f4d6e commit 2ef868d

7 files changed: 45 additions & 10 deletions

common/common.cpp

Lines changed: 9 additions & 0 deletions
@@ -999,6 +999,14 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
 #endif // GGML_USE_CUDA_SYCL_VULKAN
         return true;
     }
+    if (arg == "--rpc") {
+        if (++i >= argc) {
+            invalid_param = true;
+            return true;
+        }
+        params.rpc_servers = argv[i];
+        return true;
+    }
     if (arg == "--no-mmap") {
         params.use_mmap = false;
         return true;
@@ -1507,6 +1515,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
         printf(" -mg i, --main-gpu i the GPU to use for the model (with split-mode = none),\n");
         printf(" or for intermediate results and KV (with split-mode = row) (default: %d)\n", params.main_gpu);
     }
+    printf(" --rpc SERVERS comma separated list of RPC servers\n");
     printf(" --verbose-prompt print a verbose prompt before generation (default: %s)\n", params.verbose_prompt ? "true" : "false");
     printf(" --no-display-prompt don't print prompt at generation (default: %s)\n", !params.display_prompt ? "true" : "false");
     printf(" -gan N, --grp-attn-n N\n");

common/common.h

Lines changed: 1 addition & 0 deletions
@@ -80,6 +80,7 @@ struct gpt_params {
     float yarn_beta_slow = 1.0f; // YaRN high correction dim
     int32_t yarn_orig_ctx = 0; // YaRN original context length
     float defrag_thold = -1.0f; // KV cache defragmentation threshold
+    std::string rpc_servers = ""; // comma separated list of RPC servers
 
     ggml_backend_sched_eval_callback cb_eval = nullptr;
     void * cb_eval_user_data = nullptr;

examples/main/main.cpp

Lines changed: 1 addition & 0 deletions
@@ -187,6 +187,7 @@ int main(int argc, char ** argv) {
     LOG("%s: llama backend init\n", __func__);
     llama_backend_init();
     llama_numa_init(params.numa);
+    llama_rpc_init(params.rpc_servers.empty() ? nullptr : params.rpc_servers.c_str());
 
     llama_model * model;
     llama_context * ctx;

ggml-rpc.cpp

Lines changed: 21 additions & 10 deletions
@@ -3,7 +3,7 @@
 #include "ggml-backend-impl.h"
 #include "ggml-rpc.grpc.pb.h"
 
-#ifdef GGML_USE_CUBLAS
+#ifdef GGML_USE_CUDA
 #include "ggml-cuda.h"
 #endif
 
@@ -129,7 +129,9 @@ static ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const ggml::T
 
 GGML_CALL static void ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
     UNUSED(buffer);
-    GGML_ASSERT(!ggml_is_quantized(tensor->type) && "quantized tensors not supported");
+    if (ggml_is_quantized(tensor->type)) {
+        GGML_ASSERT(tensor->ne[0] % 512 == 0 && "unsupported quantized tensor");
+    }
 }
 
 GGML_CALL static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
@@ -344,11 +346,20 @@ static ggml_backend_i ggml_backend_rpc_interface = {
     /* .event_synchronize = */ NULL,
 };
 
-// TODO: this should be read from the command line or some configuration file
-static std::vector<std::string> SERVER_ENDPOINTS = {
-    "localhost:50051",
-    "localhost:50052",
-};
+static std::vector<std::string> endpoints;
+
+GGML_API GGML_CALL void ggml_rpc_init(const char * rpc_servers) {
+    endpoints.clear();
+    GGML_ASSERT(rpc_servers != NULL);
+    std::string servers(rpc_servers);
+    size_t pos = 0;
+    while ((pos = servers.find(",")) != std::string::npos) {
+        std::string server = servers.substr(0, pos);
+        endpoints.push_back(server);
+        servers.erase(0, pos + 1);
+    }
+    endpoints.push_back(servers);
+}
 
 static ggml_backend_t instances[GGML_RPC_MAX_SERVERS] = {0};
 
@@ -364,7 +375,7 @@ GGML_CALL ggml_backend_t ggml_backend_rpc_init(int server_id) {
     if (instances[server_id]) {
        return instances[server_id];
     }
-    std::string endpoint = SERVER_ENDPOINTS[server_id];
+    std::string endpoint = endpoints[server_id];
     GGML_PRINT_DEBUG("Connecting to %s\n", endpoint.c_str());
     auto channel = grpc::CreateChannel(endpoint, grpc::InsecureChannelCredentials());
     std::shared_ptr<ggml::Backend::Stub> stub = ggml::Backend::NewStub(channel);
@@ -400,14 +411,14 @@ GGML_API GGML_CALL bool ggml_backend_is_rpc(ggml_backend_t backend) {
 }
 
 GGML_API GGML_CALL int ggml_backend_rpc_get_server_count(void) {
-    return SERVER_ENDPOINTS.size();
+    return endpoints.size();
 }
 
 // Server-side implementation of the RPC backend
 
 BackendImpl::BackendImpl() {
     // the RPC backend simply delegates to one of the existing backends
-#ifdef GGML_USE_CUBLAS
+#ifdef GGML_USE_CUDA
     fprintf(stderr, "%s: using CUDA backend\n", __func__);
     backend = ggml_backend_cuda_init(0); // init device 0
     if (!backend) {
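
The new ggml_rpc_init replaces the hard-coded SERVER_ENDPOINTS list with a comma-split of the string handed in from the command line. A small standalone sketch of the same find/substr/erase loop, assuming nothing beyond the standard library, which also covers the edge case of a single endpoint with no comma:

#include <cassert>
#include <string>
#include <vector>

// Same splitting strategy as ggml_rpc_init above, in isolation.
static std::vector<std::string> split_endpoints(std::string servers) {
    std::vector<std::string> out;
    size_t pos = 0;
    while ((pos = servers.find(",")) != std::string::npos) {
        out.push_back(servers.substr(0, pos)); // endpoint before the comma
        servers.erase(0, pos + 1);             // drop it and the comma
    }
    out.push_back(servers);                    // last (or only) endpoint
    return out;
}

int main() {
    assert(split_endpoints("localhost:50051").size() == 1);
    assert(split_endpoints("localhost:50051,localhost:50052").size() == 2);
    return 0;
}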

ggml-rpc.h

Lines changed: 2 additions & 0 deletions
@@ -10,6 +10,8 @@ extern "C" {
 
 #define GGML_RPC_MAX_SERVERS 16
 
+GGML_API GGML_CALL void ggml_rpc_init(const char * rpc_servers);
+
 // backend API
 GGML_API GGML_CALL ggml_backend_t ggml_backend_rpc_init(int server_id);
 GGML_API GGML_CALL bool ggml_backend_is_rpc(ggml_backend_t backend);
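
With ggml_rpc_init exposed next to the existing client API, the expected wiring on the client side would presumably be: register the endpoints once, then create one RPC backend per configured server (the header caps this at GGML_RPC_MAX_SERVERS). A hedged sketch; the helper name and the endpoint strings are illustrative:

#include "ggml-rpc.h"

#include <vector>

// Hypothetical helper: register the endpoint list, then create one backend per server.
static std::vector<ggml_backend_t> init_rpc_backends(const char * rpc_servers) {
    ggml_rpc_init(rpc_servers); // e.g. "localhost:50051,localhost:50052"
    std::vector<ggml_backend_t> backends;
    const int n = ggml_backend_rpc_get_server_count();
    for (int i = 0; i < n; i++) {
        backends.push_back(ggml_backend_rpc_init(i)); // cached per server_id in instances[]
    }
    return backends;
}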

llama.cpp

Lines changed: 10 additions & 0 deletions
@@ -14957,6 +14957,16 @@ void llama_numa_init(enum ggml_numa_strategy numa) {
     }
 }
 
+void llama_rpc_init(const char * rpc_servers) {
+#ifdef GGML_USE_RPC
+    ggml_rpc_init(rpc_servers);
+#else
+    if (rpc_servers != nullptr) {
+        LLAMA_LOG_WARN("%s: RPC support is not enabled in this build\n", __func__);
+    }
+#endif
+}
+
 void llama_backend_free(void) {
 #ifdef GGML_USE_MPI
     ggml_mpi_backend_free();

llama.h

Lines changed: 1 addition & 0 deletions
@@ -358,6 +358,7 @@ extern "C" {
 
     //optional:
     LLAMA_API void llama_numa_init(enum ggml_numa_strategy numa);
+    LLAMA_API void llama_rpc_init(const char * rpc_servers);
 
     // Call once at the end of the program - currently only used for MPI
     LLAMA_API void llama_backend_free(void);
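
At the public API level the new call slots into the usual initialization sequence, exactly as examples/main/main.cpp does above. A minimal hedged sketch (model loading omitted; the endpoint string and the NUMA strategy are illustrative):

#include "llama.h"

int main(void) {
    llama_backend_init();
    llama_numa_init(GGML_NUMA_STRATEGY_DISABLED);
    // Pass NULL when the user did not ask for RPC; in a build without
    // GGML_USE_RPC, only a non-NULL value triggers the warning shown above.
    llama_rpc_init("localhost:50051,localhost:50052");

    // ... load a model and run inference here ...

    llama_backend_free();
    return 0;
}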
