common : bring back missing args, add env var duplication check #9375

Merged · 4 commits · Sep 8, 2024
113 changes: 77 additions & 36 deletions common/common.cpp
@@ -673,17 +673,8 @@ std::vector<llama_arg> gpt_params_parser_init(gpt_params & params, llama_example
* - if LLAMA_EXAMPLE_* is set (other than COMMON), we only show the option in the corresponding example
* - if both {LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_*,} are set, we will prioritize the LLAMA_EXAMPLE_* matching current example
*/
std::unordered_set<std::string> seen_args;
auto add_opt = [&](llama_arg arg) {
if (arg.in_example(ex) || arg.in_example(LLAMA_EXAMPLE_COMMON)) {
// make sure there is no argument duplications
for (const auto & a : arg.args) {
if (seen_args.find(a) == seen_args.end()) {
seen_args.insert(a);
} else {
throw std::runtime_error(format("found duplicated argument in source code: %s", a));
}
}
options.push_back(std::move(arg));
}
};
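Note that the duplicate-argument check above is the part this hunk removes; it is not dropped entirely, but moves into tests/test-arg-parser.cpp at the end of this diff. What remains of add_opt is just the example filter described by the comment — a minimal sketch, using only what this hunk shows:

    // add_opt after this PR: keep the option if it applies to the current
    // example or to LLAMA_EXAMPLE_COMMON; duplicate detection now lives in
    // the arg-parser test instead of in gpt_params_parser_init().
    auto add_opt = [&](llama_arg arg) {
        if (arg.in_example(ex) || arg.in_example(LLAMA_EXAMPLE_COMMON)) {
            options.push_back(std::move(arg));
        }
    };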
@@ -790,8 +781,7 @@ std::vector<llama_arg> gpt_params_parser_init(gpt_params & params, llama_example
add_opt(llama_arg(
{"-C", "--cpu-mask"}, "M",
"CPU affinity mask: arbitrarily long hex. Complements cpu-range (default: \"\")",
[](gpt_params & params, const std::string & value) {
std::string mask = value;
[](gpt_params & params, const std::string & mask) {
params.cpuparams.mask_valid = true;
if (!parse_cpu_mask(mask, params.cpuparams.cpumask)) {
throw std::invalid_argument("invalid cpumask");
@@ -801,8 +791,7 @@ std::vector<llama_arg> gpt_params_parser_init(gpt_params & params, llama_example
add_opt(llama_arg(
{"-Cr", "--cpu-range"}, "lo-hi",
"range of CPUs for affinity. Complements --cpu-mask",
[](gpt_params & params, const std::string & value) {
std::string range = value;
[](gpt_params & params, const std::string & range) {
params.cpuparams.mask_valid = true;
if (!parse_cpu_range(range, params.cpuparams.cpumask)) {
throw std::invalid_argument("invalid range");
@@ -816,6 +805,16 @@ std::vector<llama_arg> gpt_params_parser_init(gpt_params & params, llama_example
params.cpuparams.strict_cpu = std::stoul(value);
}
));
add_opt(llama_arg(
{"--prio"}, "N",
format("set process/thread priority : 0-normal, 1-medium, 2-high, 3-realtime (default: %d)\n", params.cpuparams.priority),
[](gpt_params & params, int prio) {
if (prio < 0 || prio > 3) {
throw std::invalid_argument("invalid value");
}
params.cpuparams.priority = (enum ggml_sched_priority) prio;
}
));
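As an aside, the --prio handler above casts a 0–3 integer straight to enum ggml_sched_priority, so it relies on the enum values matching the help text (0-normal through 3-realtime). A sketch of the assumed layout — the enumerator names come from ggml.h and should be treated as an assumption here, not as part of this diff:

    // Assumed ggml scheduler-priority enum; the values line up with the
    // 0-normal, 1-medium, 2-high, 3-realtime mapping in the help string.
    enum ggml_sched_priority {
        GGML_SCHED_PRIO_NORMAL   = 0,
        GGML_SCHED_PRIO_MEDIUM   = 1,
        GGML_SCHED_PRIO_HIGH     = 2,
        GGML_SCHED_PRIO_REALTIME = 3,
    };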
add_opt(llama_arg(
{"--poll"}, "<0...100>",
format("use polling level to wait for work (0 - no polling, default: %u)\n", (unsigned) params.cpuparams.poll),
@@ -826,8 +825,7 @@ std::vector<llama_arg> gpt_params_parser_init(gpt_params & params, llama_example
add_opt(llama_arg(
{"-Cb", "--cpu-mask-batch"}, "M",
"CPU affinity mask: arbitrarily long hex. Complements cpu-range-batch (default: same as --cpu-mask)",
[](gpt_params & params, const std::string & value) {
std::string mask = value;
[](gpt_params & params, const std::string & mask) {
params.cpuparams_batch.mask_valid = true;
if (!parse_cpu_mask(mask, params.cpuparams_batch.cpumask)) {
throw std::invalid_argument("invalid cpumask");
@@ -837,8 +835,7 @@ std::vector<llama_arg> gpt_params_parser_init(gpt_params & params, llama_example
add_opt(llama_arg(
{"-Crb", "--cpu-range-batch"}, "lo-hi",
"ranges of CPUs for affinity. Complements --cpu-mask-batch",
[](gpt_params & params, const std::string & value) {
std::string range = value;
[](gpt_params & params, const std::string & range) {
params.cpuparams_batch.mask_valid = true;
if (!parse_cpu_range(range, params.cpuparams_batch.cpumask)) {
throw std::invalid_argument("invalid range");
@@ -852,6 +849,16 @@ std::vector<llama_arg> gpt_params_parser_init(gpt_params & params, llama_example
params.cpuparams_batch.strict_cpu = value;
}
));
add_opt(llama_arg(
{"--prio-batch"}, "N",
format("set process/thread priority : 0-normal, 1-medium, 2-high, 3-realtime (default: %d)\n", params.cpuparams_batch.priority),
[](gpt_params & params, int prio) {
if (prio < 0 || prio > 3) {
throw std::invalid_argument("invalid value");
}
params.cpuparams_batch.priority = (enum ggml_sched_priority) prio;
}
));
add_opt(llama_arg(
{"--poll-batch"}, "<0|1>",
"use polling to wait for work (default: same as --poll)",
@@ -862,8 +869,7 @@ std::vector<llama_arg> gpt_params_parser_init(gpt_params & params, llama_example
add_opt(llama_arg(
{"-Cd", "--cpu-mask-draft"}, "M",
"Draft model CPU affinity mask. Complements cpu-range-draft (default: same as --cpu-mask)",
[](gpt_params & params, const std::string & value) {
std::string mask = value;
[](gpt_params & params, const std::string & mask) {
params.draft_cpuparams.mask_valid = true;
if (!parse_cpu_mask(mask, params.draft_cpuparams.cpumask)) {
throw std::invalid_argument("invalid cpumask");
@@ -873,8 +879,7 @@ std::vector<llama_arg> gpt_params_parser_init(gpt_params & params, llama_example
add_opt(llama_arg(
{"-Crd", "--cpu-range-draft"}, "lo-hi",
"Ranges of CPUs for affinity. Complements --cpu-mask-draft",
[](gpt_params & params, const std::string & value) {
std::string range = value;
[](gpt_params & params, const std::string & range) {
params.draft_cpuparams.mask_valid = true;
if (!parse_cpu_range(range, params.draft_cpuparams.cpumask)) {
throw std::invalid_argument("invalid range");
@@ -888,18 +893,37 @@ std::vector<llama_arg> gpt_params_parser_init(gpt_params & params, llama_example
params.draft_cpuparams.strict_cpu = value;
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
add_opt(llama_arg(
{"--prio-draft"}, "N",
format("set draft process/thread priority : 0-normal, 1-medium, 2-high, 3-realtime (default: %d)\n", params.draft_cpuparams.priority),
[](gpt_params & params, int prio) {
if (prio < 0 || prio > 3) {
throw std::invalid_argument("invalid value");
}
params.draft_cpuparams.priority = (enum ggml_sched_priority) prio;
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
add_opt(llama_arg(
{"--poll-draft"}, "<0|1>",
"Use polling to wait for draft model work (default: same as --poll])",
[](gpt_params & params, int value) {
params.draft_cpuparams.poll = value;
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
add_opt(llama_arg(
{"-Cbd", "--cpu-mask-batch-draft"}, "M",
"Draft model CPU affinity mask. Complements cpu-range-draft (default: same as --cpu-mask)",
[](gpt_params & params, const std::string & mask) {
params.draft_cpuparams_batch.mask_valid = true;
if (!parse_cpu_mask(mask, params.draft_cpuparams_batch.cpumask)) {
throw std::invalid_argument("invalid cpumask");
}
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
add_opt(llama_arg(
{"-Crbd", "--cpu-range-batch-draft"}, "lo-hi",
"Ranges of CPUs for affinity. Complements --cpu-mask-draft-batch)",
[](gpt_params & params, const std::string & value) {
std::string range = value;
[](gpt_params & params, const std::string & range) {
params.draft_cpuparams_batch.mask_valid = true;
if (!parse_cpu_range(range, params.draft_cpuparams_batch.cpumask)) {
throw std::invalid_argument("invalid cpumask");
@@ -913,6 +937,16 @@ std::vector<llama_arg> gpt_params_parser_init(gpt_params & params, llama_example
params.draft_cpuparams_batch.strict_cpu = value;
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
add_opt(llama_arg(
{"--prio-batch-draft"}, "N",
format("set draft process/thread priority : 0-normal, 1-medium, 2-high, 3-realtime (default: %d)\n", params.draft_cpuparams_batch.priority),
[](gpt_params & params, int prio) {
if (prio < 0 || prio > 3) {
throw std::invalid_argument("invalid value");
}
params.draft_cpuparams_batch.priority = (enum ggml_sched_priority) prio;
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
add_opt(llama_arg(
{"--poll-batch-draft"}, "<0|1>",
"Use polling to wait for draft model work (default: --poll-draft)",
@@ -1124,45 +1158,45 @@ std::vector<llama_arg> gpt_params_parser_init(gpt_params & params, llama_example
[](gpt_params & params) {
params.interactive = true;
}
).set_examples({LLAMA_EXAMPLE_INFILL}));
).set_examples({LLAMA_EXAMPLE_MAIN}));
add_opt(llama_arg(
{"-if", "--interactive-first"},
format("run in interactive mode and wait for input right away (default: %s)", params.interactive_first ? "true" : "false"),
[](gpt_params & params) {
params.interactive_first = true;
}
).set_examples({LLAMA_EXAMPLE_INFILL}));
).set_examples({LLAMA_EXAMPLE_MAIN}));
add_opt(llama_arg(
{"-mli", "--multiline-input"},
"allows you to write or paste multiple lines without ending each in '\\'",
[](gpt_params & params) {
params.multiline_input = true;
}
).set_examples({LLAMA_EXAMPLE_INFILL}));
).set_examples({LLAMA_EXAMPLE_MAIN}));
add_opt(llama_arg(
{"--in-prefix-bos"},
"prefix BOS to user inputs, preceding the `--in-prefix` string",
[](gpt_params & params) {
params.input_prefix_bos = true;
params.enable_chat_template = false;
}
).set_examples({LLAMA_EXAMPLE_INFILL}));
).set_examples({LLAMA_EXAMPLE_MAIN}));
add_opt(llama_arg(
{"--in-prefix"}, "STRING",
"string to prefix user inputs with (default: empty)",
[](gpt_params & params, const std::string & value) {
params.input_prefix = value;
params.enable_chat_template = false;
}
).set_examples({LLAMA_EXAMPLE_INFILL}));
).set_examples({LLAMA_EXAMPLE_MAIN}));
add_opt(llama_arg(
{"--in-suffix"}, "STRING",
"string to suffix after user inputs with (default: empty)",
[](gpt_params & params, const std::string & value) {
params.input_suffix = value;
params.enable_chat_template = false;
}
).set_examples({LLAMA_EXAMPLE_INFILL}));
).set_examples({LLAMA_EXAMPLE_MAIN}));
add_opt(llama_arg(
{"--no-warmup"},
"skip warming up the model with an empty run",
@@ -1499,7 +1533,7 @@ std::vector<llama_arg> gpt_params_parser_init(gpt_params & params, llama_example
}
));
add_opt(llama_arg(
{"--all-logits"},
{"--perplexity", "--all-logits"},
format("return logits for all tokens in the batch (default: %s)", params.logits_all ? "true" : "false"),
[](gpt_params & params) {
params.logits_all = true;
@@ -1554,6 +1588,13 @@ std::vector<llama_arg> gpt_params_parser_init(gpt_params & params, llama_example
params.kl_divergence = true;
}
).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
add_opt(llama_arg(
{"--save-all-logits", "--kl-divergence-base"}, "FNAME",
"set logits file",
[](gpt_params & params, const std::string & value) {
params.logits_file = value;
}
).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
add_opt(llama_arg(
{"--ppl-stride"}, "N",
format("stride for perplexity calculation (default: %d)", params.ppl_stride),
@@ -1656,7 +1697,7 @@ std::vector<llama_arg> gpt_params_parser_init(gpt_params & params, llama_example
}
));
add_opt(llama_arg(
{"-ngl", "--gpu-layers"}, "N",
{"-ngl", "--gpu-layers", "--n-gpu-layers"}, "N",
"number of layers to store in VRAM",
[](gpt_params & params, int value) {
params.n_gpu_layers = value;
@@ -1667,7 +1708,7 @@ std::vector<llama_arg> gpt_params_parser_init(gpt_params & params, llama_example
}
).set_env("LLAMA_ARG_N_GPU_LAYERS"));
add_opt(llama_arg(
{"-ngld", "--gpu-layers-draft"}, "N",
{"-ngld", "--gpu-layers-draft", "--n-gpu-layers-draft"}, "N",
"number of layers to store in VRAM for the draft model",
[](gpt_params & params, int value) {
params.n_gpu_layers_draft = value;
@@ -1802,7 +1843,7 @@ std::vector<llama_arg> gpt_params_parser_init(gpt_params & params, llama_example
[](gpt_params & params, const std::string & value) {
params.model_alias = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_MODEL"));
).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(llama_arg(
{"-m", "--model"}, "FNAME",
ex == LLAMA_EXAMPLE_EXPORT_LORA
@@ -1890,7 +1931,7 @@ std::vector<llama_arg> gpt_params_parser_init(gpt_params & params, llama_example
}
).set_examples({LLAMA_EXAMPLE_PASSKEY}));
add_opt(llama_arg(
{"-o", "--output"}, "FNAME",
{"-o", "--output", "--output-file"}, "FNAME",
format("output file (default: '%s')",
ex == LLAMA_EXAMPLE_EXPORT_LORA
? params.lora_outfile.c_str()
@@ -1932,7 +1973,7 @@ std::vector<llama_arg> gpt_params_parser_init(gpt_params & params, llama_example
}
).set_examples({LLAMA_EXAMPLE_IMATRIX}));
add_opt(llama_arg(
{"--chunk"}, "N",
{"--chunk", "--from-chunk"}, "N",
format("start processing the input from chunk N (default: %d)", params.i_chunk),
[](gpt_params & params, int value) {
params.i_chunk = value;
@@ -2057,7 +2098,7 @@ std::vector<llama_arg> gpt_params_parser_init(gpt_params & params, llama_example
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(llama_arg(
{"--timeout"}, "N",
{"-to", "--timeout"}, "N",
format("server read/write timeout in seconds (default: %d)", params.timeout_read),
[](gpt_params & params, int value) {
params.timeout_read = value;
1 change: 0 additions & 1 deletion common/common.h
@@ -211,7 +211,6 @@ struct gpt_params {
bool use_mlock = false; // use mlock to keep model in memory
bool verbose_prompt = false; // print prompt tokens before generation
bool display_prompt = true; // print prompt before generation
bool infill = false; // use infill mode
bool dump_kv_cache = false; // dump the KV cache contents for debugging purposes
bool no_kv_offload = false; // disable KV offloading
bool warmup = true; // warmup run
5 changes: 0 additions & 5 deletions examples/infill/infill.cpp
@@ -306,11 +306,6 @@ int main(int argc, char ** argv) {
LOG_TEE("\n\n");

LOG_TEE("\n##### Infill mode #####\n\n");
if (params.infill) {
printf("\n************\n");
printf("no need to specify '--infill', always running infill\n");
printf("************\n\n");
}
if (params.interactive) {
const char *control_message;
if (params.multiline_input) {
25 changes: 24 additions & 1 deletion tests/test-arg-parser.cpp
@@ -1,6 +1,7 @@
#include <string>
#include <vector>
#include <sstream>
#include <unordered_set>

#undef NDEBUG
#include <cassert>
@@ -13,7 +14,29 @@ int main(void) {
printf("test-arg-parser: make sure there is no duplicated arguments in any examples\n\n");
for (int ex = 0; ex < LLAMA_EXAMPLE_COUNT; ex++) {
try {
gpt_params_parser_init(params, (enum llama_example)ex);
auto options = gpt_params_parser_init(params, (enum llama_example)ex);
std::unordered_set<std::string> seen_args;
std::unordered_set<std::string> seen_env_vars;
for (const auto & opt : options) {
// check for args duplications
for (const auto & arg : opt.args) {
if (seen_args.find(arg) == seen_args.end()) {
seen_args.insert(arg);
} else {
fprintf(stderr, "test-arg-parser: found different handlers for the same argument: %s", arg);
exit(1);
}
}
// check for env var duplications
if (opt.env) {
if (seen_env_vars.find(opt.env) == seen_env_vars.end()) {
seen_env_vars.insert(opt.env);
} else {
fprintf(stderr, "test-arg-parser: found different handlers for the same env var: %s", opt.env);
exit(1);
}
}
}
} catch (std::exception & e) {
printf("%s\n", e.what());
assert(false);
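The env-var half of this test exists to catch registrations like the one removed from --model-alias earlier in this diff, where LLAMA_ARG_MODEL had been attached to a second option. A hypothetical illustration of the collision it flags, assuming --model is the option that legitimately owns LLAMA_ARG_MODEL (its own set_env call is not visible in the hunks above):

    // Hypothetical: two options bound to the same env var. With this PR,
    // test-arg-parser prints "found different handlers for the same env var:
    // LLAMA_ARG_MODEL" and exits with status 1.
    add_opt(llama_arg(
        {"-m", "--model"}, "FNAME",
        "model path",
        [](gpt_params & params, const std::string & value) { params.model = value; }
    ).set_env("LLAMA_ARG_MODEL"));
    add_opt(llama_arg(
        {"-a", "--model-alias"}, "ALIAS",
        "model alias",
        [](gpt_params & params, const std::string & value) { params.model_alias = value; }
    ).set_env("LLAMA_ARG_MODEL")); // duplicate env var -> caught by the test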