-#include "arg.h"
-#include "common.h"
-#include "log.h"
 #include "llama.h"
-
+#include <cstdio>
+#include <string>
 #include <vector>
 
 static void print_usage(int, char ** argv) {
-    LOG("\nexample usage:\n");
-    LOG("\n    %s -m model.gguf -p \"Hello my name is\" -n 32\n", argv[0]);
-    LOG("\n");
+    printf("\nexample usage:\n");
+    printf("\n    %s <model.gguf> [prompt]\n", argv[0]);
+    printf("\n");
 }
 
 int main(int argc, char ** argv) {
-    gpt_params params;
-
-    params.prompt = "Hello my name is";
-    params.n_predict = 32;
+    std::string model_path;
+    std::string prompt = "Hello my name is";
+    int n_predict = 32;
 
-    if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON, print_usage)) {
+    if (argc < 2) {
+        print_usage(argc, argv);
         return 1;
     }
+    model_path = argv[1];
 
-    gpt_init();
-
-    // total length of the sequence including the prompt
-    const int n_predict = params.n_predict;
-
-    // init LLM
-
-    llama_backend_init();
-    llama_numa_init(params.numa);
+    if (argc > 2) {
+        prompt = argv[2];
+        for (int i = 3; i < argc; i++) {
+            prompt += " ";
+            prompt += argv[i];
+        }
+    }
 
     // initialize the model
 
-    llama_model_params model_params = llama_model_params_from_gpt_params(params);
-
-    llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params);
+    llama_model_params model_params = llama_model_default_params();
+    model_params.n_gpu_layers = 99; // offload all layers to GPU
+    llama_model * model = llama_load_model_from_file(model_path.c_str(), model_params);
 
     if (model == NULL) {
         fprintf(stderr, "%s: error: unable to load model\n", __func__);
@@ -44,8 +41,9 @@ int main(int argc, char ** argv) {
 
     // initialize the context
 
-    llama_context_params ctx_params = llama_context_params_from_gpt_params(params);
-
+    llama_context_params ctx_params = llama_context_default_params();
+    ctx_params.n_ctx = 512; // maximum context size
+    ctx_params.no_perf = false;
     llama_context * ctx = llama_new_context_with_model(model, ctx_params);
 
     if (ctx == NULL) {
@@ -54,54 +52,58 @@ int main(int argc, char ** argv) {
     }
 
     auto sparams = llama_sampler_chain_default_params();
-
     sparams.no_perf = false;
-
     llama_sampler * smpl = llama_sampler_chain_init(sparams);
 
     llama_sampler_chain_add(smpl, llama_sampler_init_greedy());
 
     // tokenize the prompt
 
     std::vector<llama_token> tokens_list;
-    tokens_list = ::llama_tokenize(ctx, params.prompt, true);
+    int n_tokens = llama_tokenize(model, prompt.c_str(), prompt.size(), NULL, 0, true, true);
+    tokens_list.resize(-n_tokens);
+    if (llama_tokenize(model, prompt.c_str(), prompt.size(), tokens_list.data(), tokens_list.size(), true, true) < 0) {
+        fprintf(stderr, "%s: error: failed to tokenize the prompt\n", __func__);
+        return 1;
+    }
 
     const int n_ctx    = llama_n_ctx(ctx);
     const int n_kv_req = tokens_list.size() + (n_predict - tokens_list.size());
 
-    LOG("\n");
-    LOG_INF("%s: n_predict = %d, n_ctx = %d, n_kv_req = %d\n", __func__, n_predict, n_ctx, n_kv_req);
+    fprintf(stderr, "%s: n_predict = %d, n_ctx = %d, n_kv_req = %d\n", __func__, n_predict, n_ctx, n_kv_req);
+
 
     // make sure the KV cache is big enough to hold all the prompt and generated tokens
     if (n_kv_req > n_ctx) {
-        LOG_ERR("%s: error: n_kv_req > n_ctx, the required KV cache size is not big enough\n", __func__);
-        LOG_ERR("%s:        either reduce n_predict or increase n_ctx\n", __func__);
+        fprintf(stderr, "%s: error: n_kv_req > n_ctx, the required KV cache size is not big enough\n", __func__);
+        fprintf(stderr, "%s:        either reduce n_predict or increase n_ctx\n", __func__);
         return 1;
     }
 
     // print the prompt token-by-token
 
-    LOG("\n");
+    fprintf(stderr, "\n");
 
     for (auto id : tokens_list) {
-        LOG("%s", llama_token_to_piece(ctx, id).c_str());
+        char buf[128];
+        int n = llama_token_to_piece(model, id, buf, sizeof(buf), 0, true);
+        if (n < 0) {
+            fprintf(stderr, "%s: error: failed to convert token to piece\n", __func__);
+            return 1;
+        }
+        std::string s(buf, n);
+        printf("%s", s.c_str());
     }
 
     // create a llama_batch with size 512
     // we use this object to submit token data for decoding
 
-    llama_batch batch = llama_batch_init(512, 0, 1);
+    llama_batch batch = llama_batch_get_one(tokens_list.data(), tokens_list.size(), 0, 0);
 
     // evaluate the initial prompt
-    for (size_t i = 0; i < tokens_list.size(); i++) {
-        llama_batch_add(batch, tokens_list[i], i, { 0 }, false);
-    }
-
-    // llama_decode will output logits only for the last token of the prompt
-    batch.logits[batch.n_tokens - 1] = true;
 
     if (llama_decode(ctx, batch) != 0) {
-        LOG("%s: llama_decode() failed\n", __func__);
+        fprintf(stderr, "%s: llama_decode() failed\n", __func__);
         return 1;
     }
 
@@ -114,24 +116,28 @@ int main(int argc, char ** argv) {
 
     while (n_cur <= n_predict) {
         // sample the next token
+        llama_token new_token_id = llama_sampler_sample(smpl, ctx, -1);
         {
-            const llama_token new_token_id = llama_sampler_sample(smpl, ctx, -1);
 
             // is it an end of generation?
             if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) {
-                LOG("\n");
+                fprintf(stderr, "\n");
 
                 break;
            }
 
-            LOG("%s", llama_token_to_piece(ctx, new_token_id).c_str());
+            char buf[128];
+            int n = llama_token_to_piece(model, new_token_id, buf, sizeof(buf), 0, true);
+            if (n < 0) {
+                fprintf(stderr, "%s: error: failed to convert token to piece\n", __func__);
+                return 1;
+            }
+            std::string s(buf, n);
+            printf("%s", s.c_str());
             fflush(stdout);
 
             // prepare the next batch
-            llama_batch_clear(batch);
-
-            // push this new token for next evaluation
-            llama_batch_add(batch, new_token_id, n_cur, { 0 }, true);
+            batch = llama_batch_get_one(&new_token_id, 1, n_cur, 0);
 
             n_decode += 1;
         }
@@ -140,30 +146,26 @@ int main(int argc, char ** argv) {
 
         // evaluate the current batch with the transformer model
         if (llama_decode(ctx, batch)) {
-            LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1);
+            fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
             return 1;
         }
     }
 
-    LOG("\n");
+    fprintf(stderr, "\n");
 
     const auto t_main_end = ggml_time_us();
 
-    LOG_INF("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
+    fprintf(stderr, "%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
             __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
 
-    LOG("\n");
+    fprintf(stderr, "\n");
 
     llama_perf_sampler_print(smpl);
     llama_perf_context_print(ctx);
+    fprintf(stderr, "\n");
 
-    LOG("\n");
-
-    llama_batch_free(batch);
     llama_sampler_free(smpl);
     llama_free(ctx);
     llama_free_model(model);
 
-    llama_backend_free();
-
     return 0;
 }
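
Note on the tokenization change: instead of common's `::llama_tokenize()` wrapper, the example now calls the core `llama_tokenize()` twice; the first call with an empty output buffer returns the negated number of tokens required, the vector is resized to that size, and the second call fills it. Below is a minimal standalone sketch of that two-pass pattern, assuming the same `llama_tokenize(model, text, text_len, tokens, n_tokens_max, add_special, parse_special)` signature used in the diff above; the helper name `tokenize_prompt` is only for illustration and is not part of the example.

#include "llama.h"

#include <string>
#include <vector>

// Two-pass tokenization sketch: first call reports the required size
// (as a negative value), second call fills the buffer.
static std::vector<llama_token> tokenize_prompt(const llama_model * model, const std::string & prompt) {
    // first pass: NULL output buffer, return value is the negated token count
    const int n_tokens = llama_tokenize(model, prompt.c_str(), prompt.size(), NULL, 0, true, true);

    std::vector<llama_token> tokens(-n_tokens);

    // second pass: fill the buffer; a negative return value signals failure
    if (llama_tokenize(model, prompt.c_str(), prompt.size(), tokens.data(), tokens.size(), true, true) < 0) {
        tokens.clear(); // caller treats an empty vector as a tokenization error
    }

    return tokens;
}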