Skip to content

Commit e1fa956

Browse files
authored
server : add SSL support (#5926)
* add cmake build toggle to enable ssl support in server Signed-off-by: Gabe Goodhart <[email protected]> * add flags for ssl key/cert files and use SSLServer if set All SSL setup is hidden behind CPPHTTPLIB_OPENSSL_SUPPORT in the same way that the base httplib hides the SSL support Signed-off-by: Gabe Goodhart <[email protected]> * Update readme for SSL support in server Signed-off-by: Gabe Goodhart <[email protected]> * Add LLAMA_SERVER_SSL variable setup to top-level Makefile Signed-off-by: Gabe Goodhart <[email protected]> --------- Signed-off-by: Gabe Goodhart <[email protected]>
1 parent fd72d2d commit e1fa956

File tree

4 files changed

+110
-34
lines changed

4 files changed

+110
-34
lines changed

Makefile

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,10 @@ ifdef LLAMA_SERVER_VERBOSE
201201
MK_CPPFLAGS += -DSERVER_VERBOSE=$(LLAMA_SERVER_VERBOSE)
202202
endif
203203

204+
ifdef LLAMA_SERVER_SSL
205+
MK_CPPFLAGS += -DCPPHTTPLIB_OPENSSL_SUPPORT
206+
MK_LDFLAGS += -lssl -lcrypto
207+
endif
204208

205209
ifdef LLAMA_CODE_COVERAGE
206210
MK_CXXFLAGS += -fprofile-arcs -ftest-coverage -dumpbase ''

examples/server/CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,18 @@
11
set(TARGET server)
22
option(LLAMA_SERVER_VERBOSE "Build verbose logging option for Server" ON)
3+
option(LLAMA_SERVER_SSL "Build SSL support for the server" OFF)
34
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
45
add_executable(${TARGET} server.cpp utils.hpp json.hpp httplib.h)
56
install(TARGETS ${TARGET} RUNTIME)
67
target_compile_definitions(${TARGET} PRIVATE
78
SERVER_VERBOSE=$<BOOL:${LLAMA_SERVER_VERBOSE}>
89
)
910
target_link_libraries(${TARGET} PRIVATE common ${CMAKE_THREAD_LIBS_INIT})
11+
if (LLAMA_SERVER_SSL)
12+
find_package(OpenSSL REQUIRED)
13+
target_link_libraries(${TARGET} PRIVATE OpenSSL::SSL OpenSSL::Crypto)
14+
target_compile_definitions(${TARGET} PRIVATE CPPHTTPLIB_OPENSSL_SUPPORT)
15+
endif()
1016
if (WIN32)
1117
TARGET_LINK_LIBRARIES(${TARGET} PRIVATE ws2_32)
1218
endif()

examples/server/README.md

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,10 @@ see https://github.com/ggerganov/llama.cpp/issues/1437
5959
- `--log-disable`: Output logs to stdout only, default: enabled.
6060
- `--log-format FORMAT`: Define the log output to FORMAT: json or text (default: json)
6161

62+
**If compiled with `LLAMA_SERVER_SSL=ON`**
63+
- `--ssl-key-file FNAME`: path to a file containing a PEM-encoded SSL private key
64+
- `--ssl-cert-file FNAME`: path to a file containing a PEM-encoded SSL certificate
65+
6266
## Build
6367

6468
server is built alongside everything else from the root of the project
@@ -75,6 +79,28 @@ server is build alongside everything else from the root of the project
7579
cmake --build . --config Release
7680
```
7781

82+
## Build with SSL
83+
84+
server can also be built with SSL support using OpenSSL 3
85+
86+
- Using `make`:
87+
88+
```bash
89+
# NOTE: For non-system openssl, use the following:
90+
# CXXFLAGS="-I /path/to/openssl/include"
91+
# LDFLAGS="-L /path/to/openssl/lib"
92+
make LLAMA_SERVER_SSL=true server
93+
```
94+
95+
- Using `CMake`:
96+
97+
```bash
98+
mkdir build
99+
cd build
100+
cmake .. -DLLAMA_SERVER_SSL=ON
101+
make server
102+
```
103+
78104
## Quick Start
79105

80106
To get started right away, run the following command, making sure to use the correct path for the model you have:

examples/server/server.cpp

Lines changed: 74 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include <mutex>
2828
#include <thread>
2929
#include <signal.h>
30+
#include <memory>
3031

3132
using json = nlohmann::json;
3233

@@ -118,6 +119,11 @@ struct server_params {
118119

119120
std::vector<std::string> api_keys;
120121

122+
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
123+
std::string ssl_key_file = "";
124+
std::string ssl_cert_file = "";
125+
#endif
126+
121127
bool slots_endpoint = true;
122128
bool metrics_endpoint = false;
123129
};
@@ -2142,6 +2148,10 @@ static void server_print_usage(const char * argv0, const gpt_params & params, co
21422148
printf(" --path PUBLIC_PATH path from which to serve static files (default %s)\n", sparams.public_path.c_str());
21432149
printf(" --api-key API_KEY optional api key to enhance server security. If set, requests must include this key for access.\n");
21442150
printf(" --api-key-file FNAME path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access.\n");
2151+
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
2152+
printf(" --ssl-key-file FNAME path to file a PEM-encoded SSL private key\n");
2153+
printf(" --ssl-cert-file FNAME path to file a PEM-encoded SSL certificate\n");
2154+
#endif
21452155
printf(" -to N, --timeout N server read/write timeout in seconds (default: %d)\n", sparams.read_timeout);
21462156
printf(" --embeddings enable embedding vector output (default: %s)\n", params.embedding ? "enabled" : "disabled");
21472157
printf(" -np N, --parallel N number of slots for process requests (default: %d)\n", params.n_parallel);
@@ -2220,7 +2230,24 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
22202230
}
22212231
}
22222232
key_file.close();
2223-
} else if (arg == "--timeout" || arg == "-to") {
2233+
2234+
}
2235+
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
2236+
else if (arg == "--ssl-key-file") {
2237+
if (++i >= argc) {
2238+
invalid_param = true;
2239+
break;
2240+
}
2241+
sparams.ssl_key_file = argv[i];
2242+
} else if (arg == "--ssl-cert-file") {
2243+
if (++i >= argc) {
2244+
invalid_param = true;
2245+
break;
2246+
}
2247+
sparams.ssl_cert_file = argv[i];
2248+
}
2249+
#endif
2250+
else if (arg == "--timeout" || arg == "-to") {
22242251
if (++i >= argc) {
22252252
invalid_param = true;
22262253
break;
@@ -2658,21 +2685,34 @@ int main(int argc, char ** argv) {
26582685
{"system_info", llama_print_system_info()},
26592686
});
26602687

2661-
httplib::Server svr;
2688+
std::unique_ptr<httplib::Server> svr;
2689+
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
2690+
if (sparams.ssl_key_file != "" && sparams.ssl_cert_file != "") {
2691+
LOG_INFO("Running with SSL", {{"key", sparams.ssl_key_file}, {"cert", sparams.ssl_cert_file}});
2692+
svr.reset(
2693+
new httplib::SSLServer(sparams.ssl_cert_file.c_str(), sparams.ssl_key_file.c_str())
2694+
);
2695+
} else {
2696+
LOG_INFO("Running without SSL", {});
2697+
svr.reset(new httplib::Server());
2698+
}
2699+
#else
2700+
svr.reset(new httplib::Server());
2701+
#endif
26622702

26632703
std::atomic<server_state> state{SERVER_STATE_LOADING_MODEL};
26642704

2665-
svr.set_default_headers({{"Server", "llama.cpp"}});
2705+
svr->set_default_headers({{"Server", "llama.cpp"}});
26662706

26672707
// CORS preflight
2668-
svr.Options(R"(.*)", [](const httplib::Request & req, httplib::Response & res) {
2708+
svr->Options(R"(.*)", [](const httplib::Request & req, httplib::Response & res) {
26692709
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
26702710
res.set_header("Access-Control-Allow-Credentials", "true");
26712711
res.set_header("Access-Control-Allow-Methods", "POST");
26722712
res.set_header("Access-Control-Allow-Headers", "*");
26732713
});
26742714

2675-
svr.Get("/health", [&](const httplib::Request & req, httplib::Response & res) {
2715+
svr->Get("/health", [&](const httplib::Request & req, httplib::Response & res) {
26762716
server_state current_state = state.load();
26772717
switch (current_state) {
26782718
case SERVER_STATE_READY:
@@ -2728,7 +2768,7 @@ int main(int argc, char ** argv) {
27282768
});
27292769

27302770
if (sparams.slots_endpoint) {
2731-
svr.Get("/slots", [&](const httplib::Request &, httplib::Response & res) {
2771+
svr->Get("/slots", [&](const httplib::Request &, httplib::Response & res) {
27322772
// request slots data using task queue
27332773
server_task task;
27342774
task.id = ctx_server.queue_tasks.get_new_id();
@@ -2749,7 +2789,7 @@ int main(int argc, char ** argv) {
27492789
}
27502790

27512791
if (sparams.metrics_endpoint) {
2752-
svr.Get("/metrics", [&](const httplib::Request &, httplib::Response & res) {
2792+
svr->Get("/metrics", [&](const httplib::Request &, httplib::Response & res) {
27532793
// request slots data using task queue
27542794
server_task task;
27552795
task.id = ctx_server.queue_tasks.get_new_id();
@@ -2846,9 +2886,9 @@ int main(int argc, char ** argv) {
28462886
});
28472887
}
28482888

2849-
svr.set_logger(log_server_request);
2889+
svr->set_logger(log_server_request);
28502890

2851-
svr.set_exception_handler([](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) {
2891+
svr->set_exception_handler([](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) {
28522892
const char fmt[] = "500 Internal Server Error\n%s";
28532893

28542894
char buf[BUFSIZ];
@@ -2864,7 +2904,7 @@ int main(int argc, char ** argv) {
28642904
res.status = 500;
28652905
});
28662906

2867-
svr.set_error_handler([](const httplib::Request &, httplib::Response & res) {
2907+
svr->set_error_handler([](const httplib::Request &, httplib::Response & res) {
28682908
if (res.status == 401) {
28692909
res.set_content("Unauthorized", "text/plain; charset=utf-8");
28702910
}
@@ -2877,16 +2917,16 @@ int main(int argc, char ** argv) {
28772917
});
28782918

28792919
// set timeouts and change hostname and port
2880-
svr.set_read_timeout (sparams.read_timeout);
2881-
svr.set_write_timeout(sparams.write_timeout);
2920+
svr->set_read_timeout (sparams.read_timeout);
2921+
svr->set_write_timeout(sparams.write_timeout);
28822922

2883-
if (!svr.bind_to_port(sparams.hostname, sparams.port)) {
2923+
if (!svr->bind_to_port(sparams.hostname, sparams.port)) {
28842924
fprintf(stderr, "\ncouldn't bind to server socket: hostname=%s port=%d\n\n", sparams.hostname.c_str(), sparams.port);
28852925
return 1;
28862926
}
28872927

28882928
// Set the base directory for serving static files
2889-
svr.set_base_dir(sparams.public_path);
2929+
svr->set_base_dir(sparams.public_path);
28902930

28912931
std::unordered_map<std::string, std::string> log_data;
28922932

@@ -2947,30 +2987,30 @@ int main(int argc, char ** argv) {
29472987
};
29482988

29492989
// this is only called if no index.html is found in the public --path
2950-
svr.Get("/", [](const httplib::Request &, httplib::Response & res) {
2990+
svr->Get("/", [](const httplib::Request &, httplib::Response & res) {
29512991
res.set_content(reinterpret_cast<const char*>(&index_html), index_html_len, "text/html; charset=utf-8");
29522992
return false;
29532993
});
29542994

29552995
// this is only called if no index.js is found in the public --path
2956-
svr.Get("/index.js", [](const httplib::Request &, httplib::Response & res) {
2996+
svr->Get("/index.js", [](const httplib::Request &, httplib::Response & res) {
29572997
res.set_content(reinterpret_cast<const char *>(&index_js), index_js_len, "text/javascript; charset=utf-8");
29582998
return false;
29592999
});
29603000

29613001
// this is only called if no index.html is found in the public --path
2962-
svr.Get("/completion.js", [](const httplib::Request &, httplib::Response & res) {
3002+
svr->Get("/completion.js", [](const httplib::Request &, httplib::Response & res) {
29633003
res.set_content(reinterpret_cast<const char*>(&completion_js), completion_js_len, "application/javascript; charset=utf-8");
29643004
return false;
29653005
});
29663006

29673007
// this is only called if no index.html is found in the public --path
2968-
svr.Get("/json-schema-to-grammar.mjs", [](const httplib::Request &, httplib::Response & res) {
3008+
svr->Get("/json-schema-to-grammar.mjs", [](const httplib::Request &, httplib::Response & res) {
29693009
res.set_content(reinterpret_cast<const char*>(&json_schema_to_grammar_mjs), json_schema_to_grammar_mjs_len, "application/javascript; charset=utf-8");
29703010
return false;
29713011
});
29723012

2973-
svr.Get("/props", [&ctx_server](const httplib::Request & req, httplib::Response & res) {
3013+
svr->Get("/props", [&ctx_server](const httplib::Request & req, httplib::Response & res) {
29743014
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
29753015
json data = {
29763016
{ "user_name", ctx_server.name_user.c_str() },
@@ -3062,11 +3102,11 @@ int main(int argc, char ** argv) {
30623102
}
30633103
};
30643104

3065-
svr.Post("/completion", completions); // legacy
3066-
svr.Post("/completions", completions);
3067-
svr.Post("/v1/completions", completions);
3105+
svr->Post("/completion", completions); // legacy
3106+
svr->Post("/completions", completions);
3107+
svr->Post("/v1/completions", completions);
30683108

3069-
svr.Get("/v1/models", [&params, &model_meta](const httplib::Request & req, httplib::Response & res) {
3109+
svr->Get("/v1/models", [&params, &model_meta](const httplib::Request & req, httplib::Response & res) {
30703110
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
30713111

30723112
json models = {
@@ -3161,10 +3201,10 @@ int main(int argc, char ** argv) {
31613201
}
31623202
};
31633203

3164-
svr.Post("/chat/completions", chat_completions);
3165-
svr.Post("/v1/chat/completions", chat_completions);
3204+
svr->Post("/chat/completions", chat_completions);
3205+
svr->Post("/v1/chat/completions", chat_completions);
31663206

3167-
svr.Post("/infill", [&ctx_server, &validate_api_key](const httplib::Request & req, httplib::Response & res) {
3207+
svr->Post("/infill", [&ctx_server, &validate_api_key](const httplib::Request & req, httplib::Response & res) {
31683208
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
31693209
if (!validate_api_key(req, res)) {
31703210
return;
@@ -3228,11 +3268,11 @@ int main(int argc, char ** argv) {
32283268
}
32293269
});
32303270

3231-
svr.Options(R"(/.*)", [](const httplib::Request &, httplib::Response & res) {
3271+
svr->Options(R"(/.*)", [](const httplib::Request &, httplib::Response & res) {
32323272
return res.set_content("", "application/json; charset=utf-8");
32333273
});
32343274

3235-
svr.Post("/tokenize", [&ctx_server](const httplib::Request & req, httplib::Response & res) {
3275+
svr->Post("/tokenize", [&ctx_server](const httplib::Request & req, httplib::Response & res) {
32363276
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
32373277
const json body = json::parse(req.body);
32383278

@@ -3244,7 +3284,7 @@ int main(int argc, char ** argv) {
32443284
return res.set_content(data.dump(), "application/json; charset=utf-8");
32453285
});
32463286

3247-
svr.Post("/detokenize", [&ctx_server](const httplib::Request & req, httplib::Response & res) {
3287+
svr->Post("/detokenize", [&ctx_server](const httplib::Request & req, httplib::Response & res) {
32483288
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
32493289
const json body = json::parse(req.body);
32503290

@@ -3258,7 +3298,7 @@ int main(int argc, char ** argv) {
32583298
return res.set_content(data.dump(), "application/json; charset=utf-8");
32593299
});
32603300

3261-
svr.Post("/embedding", [&params, &ctx_server](const httplib::Request & req, httplib::Response & res) {
3301+
svr->Post("/embedding", [&params, &ctx_server](const httplib::Request & req, httplib::Response & res) {
32623302
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
32633303
if (!params.embedding) {
32643304
res.status = 501;
@@ -3289,7 +3329,7 @@ int main(int argc, char ** argv) {
32893329
return res.set_content(result.data.dump(), "application/json; charset=utf-8");
32903330
});
32913331

3292-
svr.Post("/v1/embeddings", [&params, &ctx_server](const httplib::Request & req, httplib::Response & res) {
3332+
svr->Post("/v1/embeddings", [&params, &ctx_server](const httplib::Request & req, httplib::Response & res) {
32933333
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
32943334
if (!params.embedding) {
32953335
res.status = 501;
@@ -3360,13 +3400,13 @@ int main(int argc, char ** argv) {
33603400
sparams.n_threads_http = std::max(params.n_parallel + 2, (int32_t) std::thread::hardware_concurrency() - 1);
33613401
}
33623402
log_data["n_threads_http"] = std::to_string(sparams.n_threads_http);
3363-
svr.new_task_queue = [&sparams] { return new httplib::ThreadPool(sparams.n_threads_http); };
3403+
svr->new_task_queue = [&sparams] { return new httplib::ThreadPool(sparams.n_threads_http); };
33643404

33653405
LOG_INFO("HTTP server listening", log_data);
33663406

33673407
// run the HTTP server in a thread - see comment below
33683408
std::thread t([&]() {
3369-
if (!svr.listen_after_bind()) {
3409+
if (!svr->listen_after_bind()) {
33703410
state.store(SERVER_STATE_ERROR);
33713411
return 1;
33723412
}
@@ -3407,7 +3447,7 @@ int main(int argc, char ** argv) {
34073447

34083448
ctx_server.queue_tasks.start_loop();
34093449

3410-
svr.stop();
3450+
svr->stop();
34113451
t.join();
34123452

34133453
llama_backend_free();

0 commit comments

Comments
 (0)