Skip to content

Commit 337aea1

Browse files
examples : add --alias option to gpt_params to set a user-friendly model name (#1614)
1 parent bb051d9 commit 337aea1

File tree

3 files changed

+21
-2
lines changed

3 files changed

+21
-2
lines changed

examples/common.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
251251
break;
252252
}
253253
params.model = argv[i];
254+
} else if (arg == "-a" || arg == "--alias") {
255+
if (++i >= argc) {
256+
invalid_param = true;
257+
break;
258+
}
259+
params.model_alias = argv[i];
254260
} else if (arg == "--lora") {
255261
if (++i >= argc) {
256262
invalid_param = true;

examples/common.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ struct gpt_params {
4545
float mirostat_eta = 0.10f; // learning rate
4646

4747
std::string model = "models/7B/ggml-model.bin"; // model path
48+
std::string model_alias = "unknown"; // model alias
4849
std::string prompt = "";
4950
std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state
5051
std::string input_prefix = ""; // string to prefix user inputs with

examples/server/server.cpp

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -400,8 +400,10 @@ void server_print_usage(int /*argc*/, char **argv, const gpt_params &params)
400400
fprintf(stderr, " number of layers to store in VRAM\n");
401401
fprintf(stderr, " -m FNAME, --model FNAME\n");
402402
fprintf(stderr, " model path (default: %s)\n", params.model.c_str());
403-
fprintf(stderr, " -host ip address to listen (default 127.0.0.1)\n");
404-
fprintf(stderr, " -port PORT port to listen (default 8080)\n");
403+
fprintf(stderr, " -a ALIAS, --alias ALIAS\n");
404+
fprintf(stderr, " set an alias for the model, will be added as `model` field in completion response\n");
405+
fprintf(stderr, " --host ip address to listen (default 127.0.0.1)\n");
406+
fprintf(stderr, " --port PORT port to listen (default 8080)\n");
405407
fprintf(stderr, "\n");
406408
}
407409

@@ -453,6 +455,15 @@ bool server_params_parse(int argc, char **argv, server_params &sparams, gpt_para
453455
}
454456
params.model = argv[i];
455457
}
458+
else if (arg == "-a" || arg == "--alias")
459+
{
460+
if (++i >= argc)
461+
{
462+
invalid_param = true;
463+
break;
464+
}
465+
params.model_alias = argv[i];
466+
}
456467
else if (arg == "--embedding")
457468
{
458469
params.embedding = true;
@@ -645,6 +656,7 @@ int main(int argc, char **argv)
645656
try
646657
{
647658
json data = {
659+
{"model", llama.params.model_alias },
648660
{"content", llama.generated_text },
649661
{"tokens_predicted", llama.num_tokens_predicted}};
650662
return res.set_content(data.dump(), "application/json");

0 commit comments

Comments (0)