Commit de3784b

Opt class for positional argument handling

Added support for positional arguments `MODEL` and `PROMPT`.

Signed-off-by: Eric Curtin <[email protected]>

Parent: 76b27d2
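
With this change, `llama-run` takes the model path and an optional prompt as bare positional arguments instead of requiring `-m`/`-p` flags. The examples added to the new help text show the accepted forms:

    llama-run your_model.gguf
    llama-run --ngl 99 your_model.gguf
    llama-run --ngl 99 your_model.gguf Hello World

Every bare argument after the second is appended to the prompt with a space, so "Hello World" above is parsed as a single two-word prompt.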

2 files changed: 85 additions, 96 deletions


examples/main/main.cpp

Lines changed: 1 addition & 0 deletions
@@ -153,6 +153,7 @@ int main(int argc, char ** argv) {
 
     // load the model and apply lora adapter, if any
     LOG_INF("%s: load the model and apply lora adapter, if any\n", __func__);
+    std::cout << params << "\n";
    common_init_result llama_init = common_init_from_params(params);
 
    model = llama_init.model;
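
The added line streams `params` to stdout before the model is loaded, which compiles only if an `operator<<` overload exists for the params type; no such overload is added in this commit, so one must be defined elsewhere in the tree. A generic sketch of the pattern being relied on (the struct and its fields here are hypothetical stand-ins, not llama.cpp's actual `common_params`):

    #include <iostream>
    #include <string>

    // Hypothetical params struct; the real type has many more fields.
    struct params_t {
        std::string model = "model.gguf";
        int n_ctx = 2048;
    };

    // Streaming overload that makes `std::cout << params` well-formed.
    static std::ostream & operator<<(std::ostream & os, const params_t & p) {
        return os << "model: " << p.model << ", n_ctx: " << p.n_ctx;
    }

    int main() {
        params_t params;
        std::cout << params << "\n"; // same shape as the line added above
        return 0;
    }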

examples/run/run.cpp

Lines changed: 84 additions & 96 deletions
@@ -4,109 +4,106 @@
 #include <unistd.h>
 #endif
 
-#include <climits>
 #include <cstdio>
 #include <cstring>
 #include <iostream>
 #include <sstream>
 #include <string>
-#include <unordered_map>
 #include <vector>
 
 #include "llama-cpp.h"
 
 typedef std::unique_ptr<char[]> char_array_ptr;
 
-struct Argument {
-    std::string flag;
-    std::string help_text;
-};
-
-struct Options {
-    std::string model_path, prompt_non_interactive;
-    int ngl = 99;
-    int n_ctx = 2048;
-};
+class Opt {
+  public:
+    int init_opt(int argc, const char ** argv) {
+        construct_help_str_();
+        // Parse arguments
+        if (parse(argc, argv)) {
+            fprintf(stderr, "Error: Failed to parse arguments.\n");
+            help();
+            return 1;
+        }
 
-class ArgumentParser {
-  public:
-    ArgumentParser(const char * program_name) : program_name(program_name) {}
+        // If help is requested, show help and exit
+        if (help_) {
+            help();
+            return 2;
+        }
 
-    void add_argument(const std::string & flag, std::string & var, const std::string & help_text = "") {
-        string_args[flag] = &var;
-        arguments.push_back({flag, help_text});
+        return 0; // Success
    }
 
-    void add_argument(const std::string & flag, int & var, const std::string & help_text = "") {
-        int_args[flag] = &var;
-        arguments.push_back({flag, help_text});
+    const char * model_ = nullptr;
+    std::string prompt_;
+    int context_size_ = 2048, ngl_ = 0;
+
+  private:
+    std::string help_str_;
+    bool help_ = false;
+
+    void construct_help_str_() {
+        help_str_ =
+            "Description:\n"
+            "  Runs a llm\n"
+            "\n"
+            "Usage:\n"
+            "  llama-run [options] MODEL [PROMPT]\n"
+            "\n"
+            "Options:\n"
+            "  -c, --context-size <value>\n"
+            "      Context size (default: " +
+            std::to_string(context_size_);
+        help_str_ +=
+            ")\n"
+            "  -n, --ngl <value>\n"
+            "      Number of GPU layers (default: " +
+            std::to_string(ngl_);
+        help_str_ +=
+            ")\n"
+            "  -h, --help\n"
+            "      Show help message\n"
+            "\n"
+            "Examples:\n"
+            "  llama-run your_model.gguf\n"
+            "  llama-run --ngl 99 your_model.gguf\n"
+            "  llama-run --ngl 99 your_model.gguf Hello World\n";
    }
 
    int parse(int argc, const char ** argv) {
+        int positional_args_i = 0;
        for (int i = 1; i < argc; ++i) {
-            std::string arg = argv[i];
-            if (string_args.count(arg)) {
-                if (i + 1 < argc) {
-                    *string_args[arg] = argv[++i];
-                } else {
-                    fprintf(stderr, "error: missing value for %s\n", arg.c_str());
-                    print_usage();
+            if (std::strcmp(argv[i], "-c") == 0 || std::strcmp(argv[i], "--context-size") == 0) {
+                if (i + 1 >= argc) {
                    return 1;
                }
-            } else if (int_args.count(arg)) {
-                if (i + 1 < argc) {
-                    if (parse_int_arg(argv[++i], *int_args[arg]) != 0) {
-                        fprintf(stderr, "error: invalid value for %s: %s\n", arg.c_str(), argv[i]);
-                        print_usage();
-                        return 1;
-                    }
-                } else {
-                    fprintf(stderr, "error: missing value for %s\n", arg.c_str());
-                    print_usage();
+
+                context_size_ = std::atoi(argv[++i]);
+            } else if (std::strcmp(argv[i], "-n") == 0 || std::strcmp(argv[i], "--ngl") == 0) {
+                if (i + 1 >= argc) {
                    return 1;
                }
+
+                ngl_ = std::atoi(argv[++i]);
+            } else if (std::strcmp(argv[i], "-h") == 0 || std::strcmp(argv[i], "--help") == 0) {
+                help_ = true;
+                return 0;
+            } else if (!positional_args_i) {
+                ++positional_args_i;
+                model_ = argv[i];
+            } else if (positional_args_i == 1) {
+                ++positional_args_i;
+                prompt_ = argv[i];
            } else {
-                fprintf(stderr, "error: unrecognized argument %s\n", arg.c_str());
-                print_usage();
-                return 1;
+                prompt_ += " " + std::string(argv[i]);
            }
        }
 
-        if (string_args["-m"]->empty()) {
-            fprintf(stderr, "error: -m is required\n");
-            print_usage();
-            return 1;
-        }
-
-        return 0;
-    }
-
-  private:
-    const char * program_name;
-    std::unordered_map<std::string, std::string *> string_args;
-    std::unordered_map<std::string, int *> int_args;
-    std::vector<Argument> arguments;
-
-    int parse_int_arg(const char * arg, int & value) {
-        char * end;
-        const long val = std::strtol(arg, &end, 10);
-        if (*end == '\0' && val >= INT_MIN && val <= INT_MAX) {
-            value = static_cast<int>(val);
-            return 0;
-        }
-        return 1;
+        return !model_; // model_ is the only required value
    }
 
-    void print_usage() const {
-        printf("\nUsage:\n");
-        printf("  %s [OPTIONS]\n\n", program_name);
-        printf("Options:\n");
-        for (const auto & arg : arguments) {
-            printf("  %-10s %s\n", arg.flag.c_str(), arg.help_text.c_str());
-        }
-
-        printf("\n");
-    }
+    void help() const { printf("%s", help_str_.c_str()); }
 };
 
 class LlamaData {
@@ -116,13 +113,13 @@ class LlamaData {
    llama_context_ptr context;
    std::vector<llama_chat_message> messages;
 
-    int init(const Options & opt) {
-        model = initialize_model(opt.model_path, opt.ngl);
+    int init(const Opt & opt) {
+        model = initialize_model(opt.model_, opt.ngl_);
        if (!model) {
            return 1;
        }
 
-        context = initialize_context(model, opt.n_ctx);
+        context = initialize_context(model, opt.context_size_);
        if (!context) {
            return 1;
        }
@@ -134,6 +131,7 @@ class LlamaData {
  private:
    // Initializes the model and returns a unique pointer to it
    llama_model_ptr initialize_model(const std::string & model_path, const int ngl) {
+        ggml_backend_load_all();
        llama_model_params model_params = llama_model_default_params();
        model_params.n_gpu_layers = ngl;
 
@@ -273,19 +271,6 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
    return 0;
 }
 
-static int parse_arguments(const int argc, const char ** argv, Options & opt) {
-    ArgumentParser parser(argv[0]);
-    parser.add_argument("-m", opt.model_path, "model");
-    parser.add_argument("-p", opt.prompt_non_interactive, "prompt");
-    parser.add_argument("-c", opt.n_ctx, "context_size");
-    parser.add_argument("-ngl", opt.ngl, "n_gpu_layers");
-    if (parser.parse(argc, argv)) {
-        return 1;
-    }
-
-    return 0;
-}
-
 static int read_user_input(std::string & user) {
    std::getline(std::cin, user);
    return user.empty(); // Indicate an error or empty input
@@ -382,17 +367,20 @@ static std::string read_pipe_data() {
 }
 
 int main(int argc, const char ** argv) {
-    Options opt;
-    if (parse_arguments(argc, argv, opt)) {
+    Opt opt;
+    const int opt_ret = opt.init_opt(argc, argv);
+    if (opt_ret == 2) {
+        return 0;
+    } else if (opt_ret) {
        return 1;
    }
 
    if (!is_stdin_a_terminal()) {
-        if (!opt.prompt_non_interactive.empty()) {
-            opt.prompt_non_interactive += "\n\n";
+        if (!opt.prompt_.empty()) {
+            opt.prompt_ += "\n\n";
        }
 
-        opt.prompt_non_interactive += read_pipe_data();
+        opt.prompt_ += read_pipe_data();
    }
 
    llama_log_set(log_callback, nullptr);
@@ -401,7 +389,7 @@ int main(int argc, const char ** argv) {
        return 1;
    }
 
-    if (chat_loop(llama_data, opt.prompt_non_interactive)) {
+    if (chat_loop(llama_data, opt.prompt_)) {
        return 1;
    }
 
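To make the new parsing rules concrete, here is a minimal, self-contained sketch that mirrors the logic of `Opt::parse` above (same rules, simplified names; an illustration, not code from the commit): the first bare argument is taken as MODEL, the second starts PROMPT, and every later bare argument is appended to the prompt with a space.

    #include <cstdio>
    #include <cstdlib>
    #include <cstring>
    #include <string>

    int main(int argc, const char ** argv) {
        const char * model = nullptr; // first positional argument (required)
        std::string prompt;           // built from the remaining positionals
        int context_size = 2048, ngl = 0;
        int positional_args_i = 0;

        for (int i = 1; i < argc; ++i) {
            if (std::strcmp(argv[i], "-c") == 0 || std::strcmp(argv[i], "--context-size") == 0) {
                if (i + 1 >= argc) { return 1; } // flag given without a value
                context_size = std::atoi(argv[++i]);
            } else if (std::strcmp(argv[i], "-n") == 0 || std::strcmp(argv[i], "--ngl") == 0) {
                if (i + 1 >= argc) { return 1; }
                ngl = std::atoi(argv[++i]);
            } else if (!positional_args_i) {
                ++positional_args_i;
                model = argv[i];      // first bare argument: MODEL
            } else if (positional_args_i == 1) {
                ++positional_args_i;
                prompt = argv[i];     // second bare argument starts PROMPT
            } else {
                prompt += " " + std::string(argv[i]); // later words extend it
            }
        }

        if (!model) { return 1; } // the model path is the only required value
        printf("model=%s prompt='%s' context=%d ngl=%d\n", model, prompt.c_str(), context_size, ngl);
        return 0;
    }

Invoked as `./demo --ngl 99 your_model.gguf Hello World`, this prints model=your_model.gguf prompt='Hello World' context=2048 ngl=99, matching what `llama-run` now does with the same arguments. Note that `std::atoi` (used in the new `parse` as well) performs no range or error checking, unlike the `std::strtol`-based `parse_int_arg` it replaces.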