-#include "arg.h"
-#include "common.h"
-#include "log.h"
 #include "llama.h"
-
+#include <cstdio>
+#include <cstring>
+#include <string>
 #include <vector>

 static void print_usage(int, char ** argv) {
-    LOG("\nexample usage:\n");
-    LOG("\n    %s -m model.gguf -p \"Hello my name is\" -n 32\n", argv[0]);
-    LOG("\n");
+    printf("\nexample usage:\n");
+    printf("\n    %s -m model.gguf [-n n_predict] [-ngl n_gpu_layers] [prompt]\n", argv[0]);
+    printf("\n");
 }

 int main(int argc, char ** argv) {
-    gpt_params params;
-
-    params.prompt = "Hello my name is";
-    params.n_predict = 32;
-
-    if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON, print_usage)) {
-        return 1;
+    // path to the model gguf file
+    std::string model_path;
+    // prompt to generate text from
+    std::string prompt = "Hello my name is";
+    // number of layers to offload to the GPU
+    int ngl = 99;
+    // number of tokens to predict
+    int n_predict = 32;
+
+    // parse command line arguments
+
+    {
+        int i = 1;
+        for (; i < argc; i++) {
+            if (strcmp(argv[i], "-m") == 0) {
+                if (i + 1 < argc) {
+                    model_path = argv[++i];
+                } else {
+                    print_usage(argc, argv);
+                    return 1;
+                }
+            } else if (strcmp(argv[i], "-n") == 0) {
+                if (i + 1 < argc) {
+                    try {
+                        n_predict = std::stoi(argv[++i]);
+                    } catch (...) {
+                        print_usage(argc, argv);
+                        return 1;
+                    }
+                } else {
+                    print_usage(argc, argv);
+                    return 1;
+                }
+            } else if (strcmp(argv[i], "-ngl") == 0) {
+                if (i + 1 < argc) {
+                    try {
+                        ngl = std::stoi(argv[++i]);
+                    } catch (...) {
+                        print_usage(argc, argv);
+                        return 1;
+                    }
+                } else {
+                    print_usage(argc, argv);
+                    return 1;
+                }
+            } else {
+                // prompt starts here
+                break;
+            }
+        }
+        if (model_path.empty()) {
+            print_usage(argc, argv);
+            return 1;
+        }
+        if (i < argc) {
+            prompt = argv[i++];
+            for (; i < argc; i++) {
+                prompt += " ";
+                prompt += argv[i];
+            }
+        }
     }

-    gpt_init();
-
-    // total length of the sequence including the prompt
-    const int n_predict = params.n_predict;
-
-    // init LLM
-
-    llama_backend_init();
-    llama_numa_init(params.numa);
-
     // initialize the model

-    llama_model_params model_params = llama_model_params_from_gpt_params(params);
+    llama_model_params model_params = llama_model_default_params();
+    model_params.n_gpu_layers = ngl;

-    llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params);
+    llama_model * model = llama_load_model_from_file(model_path.c_str(), model_params);

     if (model == NULL) {
         fprintf(stderr, "%s: error: unable to load model\n", __func__);
         return 1;
     }

+    // tokenize the prompt
+
+    // find the number of tokens in the prompt
+    const int n_prompt = -llama_tokenize(model, prompt.c_str(), prompt.size(), NULL, 0, true, true);
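+    // (with a NULL output buffer, llama_tokenize returns the negative of the required token count, hence the leading minus)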
+
+    // allocate space for the tokens and tokenize the prompt
+    std::vector<llama_token> prompt_tokens(n_prompt);
+    if (llama_tokenize(model, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true, true) < 0) {
+        fprintf(stderr, "%s: error: failed to tokenize the prompt\n", __func__);
+        return 1;
+    }
+
     // initialize the context

-    llama_context_params ctx_params = llama_context_params_from_gpt_params(params);
+    llama_context_params ctx_params = llama_context_default_params();
+    // n_ctx is the context size
+    ctx_params.n_ctx = n_prompt + n_predict - 1;
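+    // (the last sampled token is only printed, never decoded, so it does not need a slot in the context)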
+    // n_batch is the maximum number of tokens that can be processed in a single call to llama_decode
+    ctx_params.n_batch = n_prompt;
+    // enable performance counters
+    ctx_params.no_perf = false;

     llama_context * ctx = llama_new_context_with_model(model, ctx_params);

@@ -53,117 +115,87 @@ int main(int argc, char ** argv) {
         return 1;
     }

-    auto sparams = llama_sampler_chain_default_params();
+    // initialize the sampler

+    auto sparams = llama_sampler_chain_default_params();
     sparams.no_perf = false;
-
     llama_sampler * smpl = llama_sampler_chain_init(sparams);

     llama_sampler_chain_add(smpl, llama_sampler_init_greedy());
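+    // the greedy sampler always picks the most probable token, so generation is deterministic for a given prompt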

-    // tokenize the prompt
-
-    std::vector<llama_token> tokens_list;
-    tokens_list = ::llama_tokenize(ctx, params.prompt, true);
-
-    const int n_ctx = llama_n_ctx(ctx);
-    const int n_kv_req = tokens_list.size() + (n_predict - tokens_list.size());
-
-    LOG("\n");
-    LOG_INF("%s: n_predict = %d, n_ctx = %d, n_kv_req = %d\n", __func__, n_predict, n_ctx, n_kv_req);
-
-    // make sure the KV cache is big enough to hold all the prompt and generated tokens
-    if (n_kv_req > n_ctx) {
-        LOG_ERR("%s: error: n_kv_req > n_ctx, the required KV cache size is not big enough\n", __func__);
-        LOG_ERR("%s: either reduce n_predict or increase n_ctx\n", __func__);
-        return 1;
-    }
-
     // print the prompt token-by-token

-    LOG("\n");
-
-    for (auto id : tokens_list) {
-        LOG("%s", llama_token_to_piece(ctx, id).c_str());
-    }
-
-    // create a llama_batch with size 512
-    // we use this object to submit token data for decoding
-
-    llama_batch batch = llama_batch_init(512, 0, 1);
-
-    // evaluate the initial prompt
-    for (size_t i = 0; i < tokens_list.size(); i++) {
-        llama_batch_add(batch, tokens_list[i], i, { 0 }, false);
+    for (auto id : prompt_tokens) {
+        char buf[128];
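+        // llama_token_to_piece() writes the token's text into buf and returns the number of bytes written (negative on failure)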
+        int n = llama_token_to_piece(model, id, buf, sizeof(buf), 0, true);
+        if (n < 0) {
+            fprintf(stderr, "%s: error: failed to convert token to piece\n", __func__);
+            return 1;
+        }
+        std::string s(buf, n);
+        printf("%s", s.c_str());
     }

-    // llama_decode will output logits only for the last token of the prompt
-    batch.logits[batch.n_tokens - 1] = true;
+    // prepare a batch for the prompt

-    if (llama_decode(ctx, batch) != 0) {
-        LOG("%s: llama_decode() failed\n", __func__);
-        return 1;
-    }
+    llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size(), 0, 0);
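+    // llama_batch_get_one() wraps the existing token array in a single-sequence batch starting at position 0 with sequence id 0 (the tokens are not copied)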

     // main loop

-    int n_cur = batch.n_tokens;
+    const auto t_main_start = ggml_time_us();
     int n_decode = 0;
+    llama_token new_token_id;

-    const auto t_main_start = ggml_time_us();
+    for (int n_pos = 0; n_pos + batch.n_tokens < n_prompt + n_predict; ) {
+        // evaluate the current batch with the transformer model
+        if (llama_decode(ctx, batch)) {
+            fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
+            return 1;
+        }
+
+        n_pos += batch.n_tokens;

-    while (n_cur <= n_predict) {
         // sample the next token
         {
-            const llama_token new_token_id = llama_sampler_sample(smpl, ctx, -1);
+            new_token_id = llama_sampler_sample(smpl, ctx, -1);
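+            // (index -1 refers to the logits of the last token in the batch that was just decoded)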

             // is it an end of generation?
-            if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) {
-                LOG("\n");
-
+            if (llama_token_is_eog(model, new_token_id)) {
                 break;
             }

-            LOG("%s", llama_token_to_piece(ctx, new_token_id).c_str());
+            char buf[128];
+            int n = llama_token_to_piece(model, new_token_id, buf, sizeof(buf), 0, true);
+            if (n < 0) {
+                fprintf(stderr, "%s: error: failed to convert token to piece\n", __func__);
+                return 1;
+            }
+            std::string s(buf, n);
+            printf("%s", s.c_str());
             fflush(stdout);

-            // prepare the next batch
-            llama_batch_clear(batch);
-
-            // push this new token for next evaluation
-            llama_batch_add(batch, new_token_id, n_cur, { 0 }, true);
+            // prepare the next batch with the sampled token
+            batch = llama_batch_get_one(&new_token_id, 1, n_pos, 0);

             n_decode += 1;
         }
-
-        n_cur += 1;
-
-        // evaluate the current batch with the transformer model
-        if (llama_decode(ctx, batch)) {
-            LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1);
-            return 1;
-        }
     }

-    LOG("\n");
+    printf("\n");

     const auto t_main_end = ggml_time_us();

-    LOG_INF("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
+    fprintf(stderr, "%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
            __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));

-    LOG("\n");
+    fprintf(stderr, "\n");

     llama_perf_sampler_print(smpl);
     llama_perf_context_print(ctx);
+    fprintf(stderr, "\n");

-    LOG("\n");
-
-    llama_batch_free(batch);
     llama_sampler_free(smpl);
     llama_free(ctx);
     llama_free_model(model);

-    llama_backend_free();
-
     return 0;
 }