Skip to content

Commit 3e435cb

Browse files
CLI args use - instead of _, backwards compatible
1 parent 089b1c9 commit 3e435cb

File tree

1 file changed

+34
-26
lines changed

1 file changed

+34
-26
lines changed

examples/common.cpp

Lines changed: 34 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,17 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
9191
bool escape_prompt = false;
9292
std::string arg;
9393
gpt_params default_params;
94+
const std::string arg_prefix = "--";
9495

9596
for (int i = 1; i < argc; i++) {
9697
arg = argv[i];
98+
if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) {
99+
size_t arg_underscore_index = arg.find("_");
100+
while (arg_underscore_index != std::string::npos) {
101+
arg = arg.replace(arg_underscore_index, sizeof("_") - 1, "-");
102+
arg_underscore_index = arg.find("_");
103+
}
104+
}
97105

98106
if (arg == "-s" || arg == "--seed") {
99107
#if defined(GGML_USE_CUBLAS)
@@ -141,27 +149,27 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
141149
if (params.prompt.back() == '\n') {
142150
params.prompt.pop_back();
143151
}
144-
} else if (arg == "-n" || arg == "--n_predict") {
152+
} else if (arg == "-n" || arg == "--n-predict") {
145153
if (++i >= argc) {
146154
invalid_param = true;
147155
break;
148156
}
149157
params.n_predict = std::stoi(argv[i]);
150-
} else if (arg == "--top_k") {
158+
} else if (arg == "--top-k") {
151159
if (++i >= argc) {
152160
invalid_param = true;
153161
break;
154162
}
155163
params.top_k = std::stoi(argv[i]);
156-
} else if (arg == "-c" || arg == "--ctx_size") {
164+
} else if (arg == "-c" || arg == "--ctx-size") {
157165
if (++i >= argc) {
158166
invalid_param = true;
159167
break;
160168
}
161169
params.n_ctx = std::stoi(argv[i]);
162-
} else if (arg == "--memory_f32") {
170+
} else if (arg == "--memory-f32") {
163171
params.memory_f16 = false;
164-
} else if (arg == "--top_p") {
172+
} else if (arg == "--top-p") {
165173
if (++i >= argc) {
166174
invalid_param = true;
167175
break;
@@ -185,25 +193,25 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
185193
break;
186194
}
187195
params.typical_p = std::stof(argv[i]);
188-
} else if (arg == "--repeat_last_n") {
196+
} else if (arg == "--repeat-last-n") {
189197
if (++i >= argc) {
190198
invalid_param = true;
191199
break;
192200
}
193201
params.repeat_last_n = std::stoi(argv[i]);
194-
} else if (arg == "--repeat_penalty") {
202+
} else if (arg == "--repeat-penalty") {
195203
if (++i >= argc) {
196204
invalid_param = true;
197205
break;
198206
}
199207
params.repeat_penalty = std::stof(argv[i]);
200-
} else if (arg == "--frequency_penalty") {
208+
} else if (arg == "--frequency-penalty") {
201209
if (++i >= argc) {
202210
invalid_param = true;
203211
break;
204212
}
205213
params.frequency_penalty = std::stof(argv[i]);
206-
} else if (arg == "--presence_penalty") {
214+
} else if (arg == "--presence-penalty") {
207215
if (++i >= argc) {
208216
invalid_param = true;
209217
break;
@@ -215,19 +223,19 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
215223
break;
216224
}
217225
params.mirostat = std::stoi(argv[i]);
218-
} else if (arg == "--mirostat_lr") {
226+
} else if (arg == "--mirostat-lr") {
219227
if (++i >= argc) {
220228
invalid_param = true;
221229
break;
222230
}
223231
params.mirostat_eta = std::stof(argv[i]);
224-
} else if (arg == "--mirostat_ent") {
232+
} else if (arg == "--mirostat-ent") {
225233
if (++i >= argc) {
226234
invalid_param = true;
227235
break;
228236
}
229237
params.mirostat_tau = std::stof(argv[i]);
230-
} else if (arg == "-b" || arg == "--batch_size") {
238+
} else if (arg == "-b" || arg == "--batch-size") {
231239
if (++i >= argc) {
232240
invalid_param = true;
233241
break;
@@ -310,7 +318,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
310318
invalid_param = true;
311319
break;
312320
}
313-
} else if (arg == "--n_parts") {
321+
} else if (arg == "--n-parts") {
314322
if (++i >= argc) {
315323
invalid_param = true;
316324
break;
@@ -384,31 +392,31 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
384392
fprintf(stderr, " --in-suffix STRING string to suffix after user inputs with (default: empty)\n");
385393
fprintf(stderr, " -f FNAME, --file FNAME\n");
386394
fprintf(stderr, " prompt file to start generation.\n");
387-
fprintf(stderr, " -n N, --n_predict N number of tokens to predict (default: %d, -1 = infinity)\n", params.n_predict);
388-
fprintf(stderr, " --top_k N top-k sampling (default: %d, 0 = disabled)\n", params.top_k);
389-
fprintf(stderr, " --top_p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)params.top_p);
395+
fprintf(stderr, " -n N, --n-predict N number of tokens to predict (default: %d, -1 = infinity)\n", params.n_predict);
396+
fprintf(stderr, " --top-k N top-k sampling (default: %d, 0 = disabled)\n", params.top_k);
397+
fprintf(stderr, " --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)params.top_p);
390398
fprintf(stderr, " --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)params.tfs_z);
391399
fprintf(stderr, " --typical N locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n", (double)params.typical_p);
392-
fprintf(stderr, " --repeat_last_n N last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", params.repeat_last_n);
393-
fprintf(stderr, " --repeat_penalty N penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)params.repeat_penalty);
394-
fprintf(stderr, " --presence_penalty N repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)params.presence_penalty);
395-
fprintf(stderr, " --frequency_penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)params.frequency_penalty);
400+
fprintf(stderr, " --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", params.repeat_last_n);
401+
fprintf(stderr, " --repeat-penalty N penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)params.repeat_penalty);
402+
fprintf(stderr, " --presence-penalty N repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)params.presence_penalty);
403+
fprintf(stderr, " --frequency-penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)params.frequency_penalty);
396404
fprintf(stderr, " --mirostat N use Mirostat sampling.\n");
397405
fprintf(stderr, " Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n");
398406
fprintf(stderr, " (default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)\n", params.mirostat);
399-
fprintf(stderr, " --mirostat_lr N Mirostat learning rate, parameter eta (default: %.1f)\n", (double)params.mirostat_eta);
400-
fprintf(stderr, " --mirostat_ent N Mirostat target entropy, parameter tau (default: %.1f)\n", (double)params.mirostat_tau);
407+
fprintf(stderr, " --mirostat-lr N Mirostat learning rate, parameter eta (default: %.1f)\n", (double)params.mirostat_eta);
408+
fprintf(stderr, " --mirostat-ent N Mirostat target entropy, parameter tau (default: %.1f)\n", (double)params.mirostat_tau);
401409
fprintf(stderr, " -l TOKEN_ID(+/-)BIAS, --logit-bias TOKEN_ID(+/-)BIAS\n");
402410
fprintf(stderr, " modifies the likelihood of token appearing in the completion,\n");
403411
fprintf(stderr, " i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n");
404412
fprintf(stderr, " or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'\n");
405-
fprintf(stderr, " -c N, --ctx_size N size of the prompt context (default: %d)\n", params.n_ctx);
413+
fprintf(stderr, " -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
406414
fprintf(stderr, " --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
407415
fprintf(stderr, " --no-penalize-nl do not penalize newline token\n");
408-
fprintf(stderr, " --memory_f32 use f32 instead of f16 for memory key+value\n");
416+
fprintf(stderr, " --memory-f32 use f32 instead of f16 for memory key+value\n");
409417
fprintf(stderr, " --temp N temperature (default: %.1f)\n", (double)params.temp);
410-
fprintf(stderr, " --n_parts N number of model parts (default: -1 = determine from dimensions)\n");
411-
fprintf(stderr, " -b N, --batch_size N batch size for prompt processing (default: %d)\n", params.n_batch);
418+
fprintf(stderr, " --n-parts N number of model parts (default: -1 = determine from dimensions)\n");
419+
fprintf(stderr, " -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
412420
fprintf(stderr, " --perplexity compute perplexity over the prompt\n");
413421
fprintf(stderr, " --keep number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep);
414422
if (llama_mlock_supported()) {

0 commit comments

Comments
 (0)