|
10 | 10 | #include <string>
|
11 | 11 | #include <stdio.h>
|
12 | 12 |
|
| 13 | +struct rpc_server_params { |
| 14 | + std::string host = "0.0.0.0"; |
| 15 | + int port = 50052; |
| 16 | + size_t backend_mem = 0; |
| 17 | +}; |
| 18 | + |
| 19 | +static void print_usage(int /*argc*/, char ** argv, rpc_server_params params) { |
| 20 | + fprintf(stderr, "Usage: %s [options]\n\n", argv[0]); |
| 21 | + fprintf(stderr, "options:\n"); |
| 22 | + fprintf(stderr, " -h, --help show this help message and exit\n"); |
| 23 | + fprintf(stderr, " -H HOST, --host HOST host to bind to (default: %s)\n", params.host.c_str()); |
| 24 | + fprintf(stderr, " -p PORT, --port PORT port to bind to (default: %d)\n", params.port); |
| 25 | + fprintf(stderr, " -m MEM, --mem MEM backend memory size (in MB)\n"); |
| 26 | + fprintf(stderr, "\n"); |
| 27 | +} |
| 28 | + |
| 29 | +static bool rpc_server_params_parse(int argc, char ** argv, rpc_server_params & params) { |
| 30 | + std::string arg; |
| 31 | + for (int i = 1; i < argc; i++) { |
| 32 | + arg = argv[i]; |
| 33 | + if (arg == "-H" || arg == "--host") { |
| 34 | + if (++i >= argc) { |
| 35 | + return false; |
| 36 | + } |
| 37 | + params.host = argv[i]; |
| 38 | + } else if (arg == "-p" || arg == "--port") { |
| 39 | + if (++i >= argc) { |
| 40 | + return false; |
| 41 | + } |
| 42 | + params.port = std::stoi(argv[i]); |
| 43 | + if (params.port <= 0 || params.port > 65535) { |
| 44 | + return false; |
| 45 | + } |
| 46 | + } else if (arg == "-m" || arg == "--mem") { |
| 47 | + if (++i >= argc) { |
| 48 | + return false; |
| 49 | + } |
| 50 | + params.backend_mem = std::stoul(argv[i]) * 1024 * 1024; |
| 51 | + } else if (arg == "-h" || arg == "--help") { |
| 52 | + print_usage(argc, argv, params); |
| 53 | + exit(0); |
| 54 | + } |
| 55 | + } |
| 56 | + return true; |
| 57 | +} |
| 58 | + |
13 | 59 | static ggml_backend_t create_backend() {
|
14 | 60 | ggml_backend_t backend = NULL;
|
15 | 61 | #ifdef GGML_USE_CUDA
|
@@ -45,25 +91,25 @@ static void get_backend_memory(size_t * free_mem, size_t * total_mem) {
|
45 | 91 | }
|
46 | 92 |
|
47 | 93 | int main(int argc, char * argv[]) {
|
48 |
| - if (argc < 3) { |
49 |
| - fprintf(stderr, "Usage: %s <host> <port>\n", argv[0]); |
50 |
| - return 1; |
51 |
| - } |
52 |
| - const char * host = argv[1]; |
53 |
| - int port = std::stoi(argv[2]); |
54 |
| - if (port <= 0 || port > 65535) { |
55 |
| - fprintf(stderr, "Invalid port number: %d\n", port); |
| 94 | + rpc_server_params params; |
| 95 | + if (!rpc_server_params_parse(argc, argv, params)) { |
| 96 | + fprintf(stderr, "Invalid parameters\n"); |
56 | 97 | return 1;
|
57 | 98 | }
|
58 | 99 | ggml_backend_t backend = create_backend();
|
59 | 100 | if (!backend) {
|
60 | 101 | fprintf(stderr, "Failed to create backend\n");
|
61 | 102 | return 1;
|
62 | 103 | }
|
63 |
| - printf("Starting RPC server on %s:%d\n", host, port); |
| 104 | + std::string endpoint = params.host + ":" + std::to_string(params.port); |
64 | 105 | size_t free_mem, total_mem;
|
65 |
| - get_backend_memory(&free_mem, &total_mem); |
66 |
| - std::string endpoint = std::string(host) + ":" + std::to_string(port); |
| 106 | + if (params.backend_mem > 0) { |
| 107 | + free_mem = params.backend_mem; |
| 108 | + total_mem = params.backend_mem; |
| 109 | + } else { |
| 110 | + get_backend_memory(&free_mem, &total_mem); |
| 111 | + } |
| 112 | + printf("Starting RPC server on %s, backend memory: %zu MB\n", endpoint.c_str(), free_mem / (1024 * 1024)); |
67 | 113 | start_rpc_server(backend, endpoint.c_str(), free_mem, total_mem);
|
68 | 114 | ggml_backend_free(backend);
|
69 | 115 | return 0;
|
|
0 commit comments