Skip to content

Commit 8be551b

Browse files
ngxson authored and tybalex committed
Server: clean up OAI params parsing function (ggml-org#6284)
* server: clean up oai parsing function
* fix response_format
* fix empty response_format
* minor fixes
* add TODO for logprobs
* update docs
1 parent 247cc1c commit 8be551b

File tree

3 files changed

+61
-37
lines changed

3 files changed

+61
-37
lines changed

examples/server/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,7 @@ Notice that each `probs` is an array of length `n_probs`.
360360
- `default_generation_settings` - the default generation settings for the `/completion` endpoint, has the same fields as the `generation_settings` response object from the `/completion` endpoint.
361361
- `total_slots` - the total number of slots for process requests (defined by `--parallel` option)
362362

363-
- **POST** `/v1/chat/completions`: OpenAI-compatible Chat Completions API. Given a ChatML-formatted json description in `messages`, it returns the predicted completion. Both synchronous and streaming mode are supported, so scripted and interactive applications work fine. While no strong claims of compatibility with OpenAI API spec is being made, in our experience it suffices to support many apps. Only ChatML-tuned models, such as Dolphin, OpenOrca, OpenHermes, OpenChat-3.5, etc can be used with this endpoint.
363+
- **POST** `/v1/chat/completions`: OpenAI-compatible Chat Completions API. Given a ChatML-formatted json description in `messages`, it returns the predicted completion. Both synchronous and streaming mode are supported, so scripted and interactive applications work fine. While no strong claims of compatibility with OpenAI API spec is being made, in our experience it suffices to support many apps. Only model with [supported chat template](https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template) can be used optimally with this endpoint. By default, ChatML template will be used.
364364

365365
*Options:*
366366

examples/server/server.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -847,9 +847,16 @@ struct server_context {
847847
slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
848848
slot.params.n_keep = json_value(data, "n_keep", slot.params.n_keep);
849849
slot.params.seed = json_value(data, "seed", default_params.seed);
850-
if (data.contains("json_schema") && !data.contains("grammar")) {
850+
slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
851+
slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
852+
853+
// process "json_schema" and "grammar"
854+
if (data.contains("json_schema") && data.contains("grammar")) {
855+
send_error(task, "Either \"json_schema\" or \"grammar\" can be specified, but not both", ERROR_TYPE_INVALID_REQUEST);
856+
return false;
857+
} else if (data.contains("json_schema") && !data.contains("grammar")) {
851858
try {
852-
auto schema = json_value(data, "json_schema", json::object());
859+
auto schema = json_value(data, "json_schema", json::object());
853860
slot.sparams.grammar = json_schema_to_grammar(schema);
854861
} catch (const std::exception & e) {
855862
send_error(task, std::string("\"json_schema\": ") + e.what(), ERROR_TYPE_INVALID_REQUEST);
@@ -858,8 +865,6 @@ struct server_context {
858865
} else {
859866
slot.sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
860867
}
861-
slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
862-
slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
863868

864869
if (slot.params.cache_prompt && slot.ga_n != 1) {
865870
LOG_WARNING("cache_prompt is not supported with group-attention", {});

examples/server/utils.hpp

Lines changed: 51 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -725,50 +725,69 @@ static json oaicompat_completion_params_parse(
725725
// https://platform.openai.com/docs/api-reference/chat/create
726726
llama_sampling_params default_sparams;
727727
llama_params["model"] = json_value(body, "model", std::string("unknown"));
728-
llama_params["cache_prompt"] = json_value(body, "cache_prompt", false);
729-
llama_params["temperature"] = json_value(body, "temperature", 0.0);
730-
llama_params["top_k"] = json_value(body, "top_k", default_sparams.top_k);
731-
llama_params["top_p"] = json_value(body, "top_p", 1.0);
732-
llama_params["n_predict"] = json_value(body, "max_tokens", -1);
733-
llama_params["logit_bias"] = json_value(body, "logit_bias", json::object());
734728
llama_params["frequency_penalty"] = json_value(body, "frequency_penalty", 0.0);
729+
llama_params["logit_bias"] = json_value(body, "logit_bias", json::object());
730+
llama_params["n_predict"] = json_value(body, "max_tokens", -1);
735731
llama_params["presence_penalty"] = json_value(body, "presence_penalty", 0.0);
736732
llama_params["seed"] = json_value(body, "seed", LLAMA_DEFAULT_SEED);
737733
llama_params["stream"] = json_value(body, "stream", false);
738-
llama_params["mirostat"] = json_value(body, "mirostat", default_sparams.mirostat);
739-
llama_params["mirostat_tau"] = json_value(body, "mirostat_tau", default_sparams.mirostat_tau);
740-
llama_params["mirostat_eta"] = json_value(body, "mirostat_eta", default_sparams.mirostat_eta);
741-
llama_params["penalize_nl"] = json_value(body, "penalize_nl", default_sparams.penalize_nl);
742-
llama_params["typical_p"] = json_value(body, "typical_p", default_sparams.typical_p);
743-
llama_params["repeat_last_n"] = json_value(body, "repeat_last_n", default_sparams.penalty_last_n);
744-
llama_params["ignore_eos"] = json_value(body, "ignore_eos", false);
745-
llama_params["tfs_z"] = json_value(body, "tfs_z", default_sparams.tfs_z);
746-
llama_params["n_keep"] = json_value(body, "n_keep", 0);
747-
748-
if (body.contains("grammar")) {
749-
llama_params["grammar"] = json_value(body, "grammar", json::object());
750-
}
734+
llama_params["temperature"] = json_value(body, "temperature", 0.0);
735+
llama_params["top_p"] = json_value(body, "top_p", 1.0);
751736

752-
if (body.contains("response_format")) {
753-
auto response_format = json_value(body, "response_format", json::object());
754-
if (response_format.contains("type")) {
755-
if (response_format["type"] == "json_object") {
756-
llama_params["json_schema"] = json_value(response_format, "schema", json::object());
757-
} else {
758-
throw std::runtime_error("response_format type not supported: " + response_format["type"].dump());
759-
}
760-
}
761-
}
762737

763-
// Handle 'stop' field
738+
// Handle "stop" field
764739
if (body.contains("stop") && body["stop"].is_string()) {
765740
llama_params["stop"] = json::array({body["stop"].get<std::string>()});
766741
} else {
767742
llama_params["stop"] = json_value(body, "stop", json::array());
768743
}
744+
// Some chat templates don't use EOS token to stop generation
745+
// We must add their end sequences to list of stop words
746+
llama_params["stop"].push_back("<|im_end|>"); // chatml
747+
llama_params["stop"].push_back("<end_of_turn>"); // gemma
748+
749+
// Handle "response_format" field
750+
if (body.contains("response_format")) {
751+
json response_format = json_value(body, "response_format", json::object());
752+
std::string response_type = json_value(response_format, "type", std::string());
753+
if (response_type == "json_object") {
754+
llama_params["json_schema"] = json_value(response_format, "schema", json::object());
755+
} else if (!response_type.empty() && response_type != "text") {
756+
throw std::runtime_error("response_format type must be one of \"text\" or \"json_object\", but got: " + response_type);
757+
}
758+
}
769759

770-
// Ensure there is ChatML-specific end sequence among stop words
771-
llama_params["stop"].push_back("<|im_end|>");
760+
// Handle "n" field
761+
int n_choices = json_value(body, "n", 1);
762+
if (n_choices != 1) {
763+
throw std::runtime_error("Only one completion choice is allowed");
764+
}
765+
766+
// Handle "logprobs" field
767+
// TODO: The response format of this option is not yet OAI-compatible, but seems like no one really using it; We may need to fix it in the future
768+
if (body.contains("logprobs")) {
769+
llama_params["n_probs"] = json_value(body, "top_logprobs", 20);
770+
} else if (body.contains("top_logprobs")) {
771+
throw std::runtime_error("top_logprobs requires logprobs to be set to true");
772+
}
773+
774+
// Params supported by OAI but unsupported by llama.cpp
775+
// static const std::vector<std::string> unsupported_params { "tools", "tool_choice" };
776+
// for (auto & param : unsupported_params) {
777+
// if (body.contains(param)) {
778+
// throw std::runtime_error("Unsupported param: " + param);
779+
// }
780+
// }
781+
782+
// Copy remaining properties to llama_params
783+
// This allows user to use llama.cpp-specific params like "mirostat", "tfs_z",... via OAI endpoint.
784+
// See "launch_slot_with_task()" for a complete list of params supported by llama.cpp
785+
for (const auto & item : body.items()) {
786+
// Exception: if "n_predict" is present, we overwrite the value specified earlier by "max_tokens"
787+
if (!llama_params.contains(item.key()) || item.key() == "n_predict") {
788+
llama_params[item.key()] = item.value();
789+
}
790+
}
772791

773792
return llama_params;
774793
}

0 commit comments

Comments
 (0)