27
27
#include < mutex>
28
28
#include < thread>
29
29
#include < signal.h>
30
+ #include < memory>
30
31
31
32
using json = nlohmann::json;
32
33
@@ -118,6 +119,11 @@ struct server_params {
118
119
119
120
std::vector<std::string> api_keys;
120
121
122
+ #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
123
+ std::string ssl_key_file = " " ;
124
+ std::string ssl_cert_file = " " ;
125
+ #endif
126
+
121
127
bool slots_endpoint = true ;
122
128
bool metrics_endpoint = false ;
123
129
};
@@ -2142,6 +2148,10 @@ static void server_print_usage(const char * argv0, const gpt_params & params, co
2142
2148
printf (" --path PUBLIC_PATH path from which to serve static files (default %s)\n " , sparams.public_path .c_str ());
2143
2149
printf (" --api-key API_KEY optional api key to enhance server security. If set, requests must include this key for access.\n " );
2144
2150
printf (" --api-key-file FNAME path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access.\n " );
2151
+ #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
2152
+ printf (" --ssl-key-file FNAME path to file a PEM-encoded SSL private key\n " );
2153
+ printf (" --ssl-cert-file FNAME path to file a PEM-encoded SSL certificate\n " );
2154
+ #endif
2145
2155
printf (" -to N, --timeout N server read/write timeout in seconds (default: %d)\n " , sparams.read_timeout );
2146
2156
printf (" --embeddings enable embedding vector output (default: %s)\n " , params.embedding ? " enabled" : " disabled" );
2147
2157
printf (" -np N, --parallel N number of slots for process requests (default: %d)\n " , params.n_parallel );
@@ -2220,7 +2230,24 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
2220
2230
}
2221
2231
}
2222
2232
key_file.close ();
2223
- } else if (arg == " --timeout" || arg == " -to" ) {
2233
+
2234
+ }
2235
+ #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
2236
+ else if (arg == " --ssl-key-file" ) {
2237
+ if (++i >= argc) {
2238
+ invalid_param = true ;
2239
+ break ;
2240
+ }
2241
+ sparams.ssl_key_file = argv[i];
2242
+ } else if (arg == " --ssl-cert-file" ) {
2243
+ if (++i >= argc) {
2244
+ invalid_param = true ;
2245
+ break ;
2246
+ }
2247
+ sparams.ssl_cert_file = argv[i];
2248
+ }
2249
+ #endif
2250
+ else if (arg == " --timeout" || arg == " -to" ) {
2224
2251
if (++i >= argc) {
2225
2252
invalid_param = true ;
2226
2253
break ;
@@ -2658,21 +2685,34 @@ int main(int argc, char ** argv) {
2658
2685
{" system_info" , llama_print_system_info ()},
2659
2686
});
2660
2687
2661
- httplib::Server svr;
2688
+ std::unique_ptr<httplib::Server> svr;
2689
+ #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
2690
+ if (sparams.ssl_key_file != " " && sparams.ssl_cert_file != " " ) {
2691
+ LOG_INFO (" Running with SSL" , {{" key" , sparams.ssl_key_file }, {" cert" , sparams.ssl_cert_file }});
2692
+ svr.reset (
2693
+ new httplib::SSLServer (sparams.ssl_cert_file .c_str (), sparams.ssl_key_file .c_str ())
2694
+ );
2695
+ } else {
2696
+ LOG_INFO (" Running without SSL" , {});
2697
+ svr.reset (new httplib::Server ());
2698
+ }
2699
+ #else
2700
+ svr.reset (new httplib::Server ());
2701
+ #endif
2662
2702
2663
2703
std::atomic<server_state> state{SERVER_STATE_LOADING_MODEL};
2664
2704
2665
- svr. set_default_headers ({{" Server" , " llama.cpp" }});
2705
+ svr-> set_default_headers ({{" Server" , " llama.cpp" }});
2666
2706
2667
2707
// CORS preflight
2668
- svr. Options (R"( .*)" , [](const httplib::Request & req, httplib::Response & res) {
2708
+ svr-> Options (R"( .*)" , [](const httplib::Request & req, httplib::Response & res) {
2669
2709
res.set_header (" Access-Control-Allow-Origin" , req.get_header_value (" Origin" ));
2670
2710
res.set_header (" Access-Control-Allow-Credentials" , " true" );
2671
2711
res.set_header (" Access-Control-Allow-Methods" , " POST" );
2672
2712
res.set_header (" Access-Control-Allow-Headers" , " *" );
2673
2713
});
2674
2714
2675
- svr. Get (" /health" , [&](const httplib::Request & req, httplib::Response & res) {
2715
+ svr-> Get (" /health" , [&](const httplib::Request & req, httplib::Response & res) {
2676
2716
server_state current_state = state.load ();
2677
2717
switch (current_state) {
2678
2718
case SERVER_STATE_READY:
@@ -2728,7 +2768,7 @@ int main(int argc, char ** argv) {
2728
2768
});
2729
2769
2730
2770
if (sparams.slots_endpoint ) {
2731
- svr. Get (" /slots" , [&](const httplib::Request &, httplib::Response & res) {
2771
+ svr-> Get (" /slots" , [&](const httplib::Request &, httplib::Response & res) {
2732
2772
// request slots data using task queue
2733
2773
server_task task;
2734
2774
task.id = ctx_server.queue_tasks .get_new_id ();
@@ -2749,7 +2789,7 @@ int main(int argc, char ** argv) {
2749
2789
}
2750
2790
2751
2791
if (sparams.metrics_endpoint ) {
2752
- svr. Get (" /metrics" , [&](const httplib::Request &, httplib::Response & res) {
2792
+ svr-> Get (" /metrics" , [&](const httplib::Request &, httplib::Response & res) {
2753
2793
// request slots data using task queue
2754
2794
server_task task;
2755
2795
task.id = ctx_server.queue_tasks .get_new_id ();
@@ -2846,9 +2886,9 @@ int main(int argc, char ** argv) {
2846
2886
});
2847
2887
}
2848
2888
2849
- svr. set_logger (log_server_request);
2889
+ svr-> set_logger (log_server_request);
2850
2890
2851
- svr. set_exception_handler ([](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) {
2891
+ svr-> set_exception_handler ([](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) {
2852
2892
const char fmt[] = " 500 Internal Server Error\n %s" ;
2853
2893
2854
2894
char buf[BUFSIZ];
@@ -2864,7 +2904,7 @@ int main(int argc, char ** argv) {
2864
2904
res.status = 500 ;
2865
2905
});
2866
2906
2867
- svr. set_error_handler ([](const httplib::Request &, httplib::Response & res) {
2907
+ svr-> set_error_handler ([](const httplib::Request &, httplib::Response & res) {
2868
2908
if (res.status == 401 ) {
2869
2909
res.set_content (" Unauthorized" , " text/plain; charset=utf-8" );
2870
2910
}
@@ -2877,16 +2917,16 @@ int main(int argc, char ** argv) {
2877
2917
});
2878
2918
2879
2919
// set timeouts and change hostname and port
2880
- svr. set_read_timeout (sparams.read_timeout );
2881
- svr. set_write_timeout (sparams.write_timeout );
2920
+ svr-> set_read_timeout (sparams.read_timeout );
2921
+ svr-> set_write_timeout (sparams.write_timeout );
2882
2922
2883
- if (!svr. bind_to_port (sparams.hostname , sparams.port )) {
2923
+ if (!svr-> bind_to_port (sparams.hostname , sparams.port )) {
2884
2924
fprintf (stderr, " \n couldn't bind to server socket: hostname=%s port=%d\n\n " , sparams.hostname .c_str (), sparams.port );
2885
2925
return 1 ;
2886
2926
}
2887
2927
2888
2928
// Set the base directory for serving static files
2889
- svr. set_base_dir (sparams.public_path );
2929
+ svr-> set_base_dir (sparams.public_path );
2890
2930
2891
2931
std::unordered_map<std::string, std::string> log_data;
2892
2932
@@ -2947,30 +2987,30 @@ int main(int argc, char ** argv) {
2947
2987
};
2948
2988
2949
2989
// this is only called if no index.html is found in the public --path
2950
- svr. Get (" /" , [](const httplib::Request &, httplib::Response & res) {
2990
+ svr-> Get (" /" , [](const httplib::Request &, httplib::Response & res) {
2951
2991
res.set_content (reinterpret_cast <const char *>(&index_html), index_html_len, " text/html; charset=utf-8" );
2952
2992
return false ;
2953
2993
});
2954
2994
2955
2995
// this is only called if no index.js is found in the public --path
2956
- svr. Get (" /index.js" , [](const httplib::Request &, httplib::Response & res) {
2996
+ svr-> Get (" /index.js" , [](const httplib::Request &, httplib::Response & res) {
2957
2997
res.set_content (reinterpret_cast <const char *>(&index_js), index_js_len, " text/javascript; charset=utf-8" );
2958
2998
return false ;
2959
2999
});
2960
3000
2961
3001
// this is only called if no completion.js is found in the public --path
2962
- svr. Get (" /completion.js" , [](const httplib::Request &, httplib::Response & res) {
3002
+ svr-> Get (" /completion.js" , [](const httplib::Request &, httplib::Response & res) {
2963
3003
res.set_content (reinterpret_cast <const char *>(&completion_js), completion_js_len, " application/javascript; charset=utf-8" );
2964
3004
return false ;
2965
3005
});
2966
3006
2967
3007
// this is only called if no json-schema-to-grammar.mjs is found in the public --path
2968
- svr. Get (" /json-schema-to-grammar.mjs" , [](const httplib::Request &, httplib::Response & res) {
3008
+ svr-> Get (" /json-schema-to-grammar.mjs" , [](const httplib::Request &, httplib::Response & res) {
2969
3009
res.set_content (reinterpret_cast <const char *>(&json_schema_to_grammar_mjs), json_schema_to_grammar_mjs_len, " application/javascript; charset=utf-8" );
2970
3010
return false ;
2971
3011
});
2972
3012
2973
- svr. Get (" /props" , [&ctx_server](const httplib::Request & req, httplib::Response & res) {
3013
+ svr-> Get (" /props" , [&ctx_server](const httplib::Request & req, httplib::Response & res) {
2974
3014
res.set_header (" Access-Control-Allow-Origin" , req.get_header_value (" Origin" ));
2975
3015
json data = {
2976
3016
{ " user_name" , ctx_server.name_user .c_str () },
@@ -3062,11 +3102,11 @@ int main(int argc, char ** argv) {
3062
3102
}
3063
3103
};
3064
3104
3065
- svr. Post (" /completion" , completions); // legacy
3066
- svr. Post (" /completions" , completions);
3067
- svr. Post (" /v1/completions" , completions);
3105
+ svr-> Post (" /completion" , completions); // legacy
3106
+ svr-> Post (" /completions" , completions);
3107
+ svr-> Post (" /v1/completions" , completions);
3068
3108
3069
- svr. Get (" /v1/models" , [¶ms, &model_meta](const httplib::Request & req, httplib::Response & res) {
3109
+ svr-> Get (" /v1/models" , [¶ms, &model_meta](const httplib::Request & req, httplib::Response & res) {
3070
3110
res.set_header (" Access-Control-Allow-Origin" , req.get_header_value (" Origin" ));
3071
3111
3072
3112
json models = {
@@ -3161,10 +3201,10 @@ int main(int argc, char ** argv) {
3161
3201
}
3162
3202
};
3163
3203
3164
- svr. Post (" /chat/completions" , chat_completions);
3165
- svr. Post (" /v1/chat/completions" , chat_completions);
3204
+ svr-> Post (" /chat/completions" , chat_completions);
3205
+ svr-> Post (" /v1/chat/completions" , chat_completions);
3166
3206
3167
- svr. Post (" /infill" , [&ctx_server, &validate_api_key](const httplib::Request & req, httplib::Response & res) {
3207
+ svr-> Post (" /infill" , [&ctx_server, &validate_api_key](const httplib::Request & req, httplib::Response & res) {
3168
3208
res.set_header (" Access-Control-Allow-Origin" , req.get_header_value (" Origin" ));
3169
3209
if (!validate_api_key (req, res)) {
3170
3210
return ;
@@ -3228,11 +3268,11 @@ int main(int argc, char ** argv) {
3228
3268
}
3229
3269
});
3230
3270
3231
- svr. Options (R"( /.*)" , [](const httplib::Request &, httplib::Response & res) {
3271
+ svr-> Options (R"( /.*)" , [](const httplib::Request &, httplib::Response & res) {
3232
3272
return res.set_content (" " , " application/json; charset=utf-8" );
3233
3273
});
3234
3274
3235
- svr. Post (" /tokenize" , [&ctx_server](const httplib::Request & req, httplib::Response & res) {
3275
+ svr-> Post (" /tokenize" , [&ctx_server](const httplib::Request & req, httplib::Response & res) {
3236
3276
res.set_header (" Access-Control-Allow-Origin" , req.get_header_value (" Origin" ));
3237
3277
const json body = json::parse (req.body );
3238
3278
@@ -3244,7 +3284,7 @@ int main(int argc, char ** argv) {
3244
3284
return res.set_content (data.dump (), " application/json; charset=utf-8" );
3245
3285
});
3246
3286
3247
- svr. Post (" /detokenize" , [&ctx_server](const httplib::Request & req, httplib::Response & res) {
3287
+ svr-> Post (" /detokenize" , [&ctx_server](const httplib::Request & req, httplib::Response & res) {
3248
3288
res.set_header (" Access-Control-Allow-Origin" , req.get_header_value (" Origin" ));
3249
3289
const json body = json::parse (req.body );
3250
3290
@@ -3258,7 +3298,7 @@ int main(int argc, char ** argv) {
3258
3298
return res.set_content (data.dump (), " application/json; charset=utf-8" );
3259
3299
});
3260
3300
3261
- svr. Post (" /embedding" , [¶ms, &ctx_server](const httplib::Request & req, httplib::Response & res) {
3301
+ svr-> Post (" /embedding" , [¶ms, &ctx_server](const httplib::Request & req, httplib::Response & res) {
3262
3302
res.set_header (" Access-Control-Allow-Origin" , req.get_header_value (" Origin" ));
3263
3303
if (!params.embedding ) {
3264
3304
res.status = 501 ;
@@ -3289,7 +3329,7 @@ int main(int argc, char ** argv) {
3289
3329
return res.set_content (result.data .dump (), " application/json; charset=utf-8" );
3290
3330
});
3291
3331
3292
- svr. Post (" /v1/embeddings" , [¶ms, &ctx_server](const httplib::Request & req, httplib::Response & res) {
3332
+ svr-> Post (" /v1/embeddings" , [¶ms, &ctx_server](const httplib::Request & req, httplib::Response & res) {
3293
3333
res.set_header (" Access-Control-Allow-Origin" , req.get_header_value (" Origin" ));
3294
3334
if (!params.embedding ) {
3295
3335
res.status = 501 ;
@@ -3360,13 +3400,13 @@ int main(int argc, char ** argv) {
3360
3400
sparams.n_threads_http = std::max (params.n_parallel + 2 , (int32_t ) std::thread::hardware_concurrency () - 1 );
3361
3401
}
3362
3402
log_data[" n_threads_http" ] = std::to_string (sparams.n_threads_http );
3363
- svr. new_task_queue = [&sparams] { return new httplib::ThreadPool (sparams.n_threads_http ); };
3403
+ svr-> new_task_queue = [&sparams] { return new httplib::ThreadPool (sparams.n_threads_http ); };
3364
3404
3365
3405
LOG_INFO (" HTTP server listening" , log_data);
3366
3406
3367
3407
// run the HTTP server in a thread - see comment below
3368
3408
std::thread t ([&]() {
3369
- if (!svr. listen_after_bind ()) {
3409
+ if (!svr-> listen_after_bind ()) {
3370
3410
state.store (SERVER_STATE_ERROR);
3371
3411
return 1 ;
3372
3412
}
@@ -3407,7 +3447,7 @@ int main(int argc, char ** argv) {
3407
3447
3408
3448
ctx_server.queue_tasks .start_loop ();
3409
3449
3410
- svr. stop ();
3450
+ svr-> stop ();
3411
3451
t.join ();
3412
3452
3413
3453
llama_backend_free ();
0 commit comments