@@ -32,12 +32,18 @@ int main(int argc, char ** argv) {
         params.prompt = "Hello my name is";
     }
 
+    // total length of the sequences including the prompt
+    const int n_len = 32;
+
     // init LLM
 
    llama_backend_init(params.numa);
 
    llama_context_params ctx_params = llama_context_default_params();
 
+    ctx_params.seed = 1234;
+    ctx_params.n_ctx = 2048;
+
    llama_model * model = llama_load_model_from_file(params.model.c_str(), ctx_params);
 
    if (model == NULL) {
@@ -47,87 +53,187 @@ int main(int argc, char ** argv) {
 
     llama_context * ctx = llama_new_context_with_model(model, ctx_params);
 
+    if (ctx == NULL) {
+        fprintf(stderr, "%s: error: failed to create the llama_context\n", __func__);
+        return 1;
+    }
+
     // tokenize the prompt
 
     std::vector<llama_token> tokens_list;
     tokens_list = ::llama_tokenize(ctx, params.prompt, true);
 
-    const int max_context_size = llama_n_ctx(ctx);
-    const int max_tokens_list_size = max_context_size - 4;
+    const int n_ctx    = llama_n_ctx(ctx);
+    const int n_kv_req = tokens_list.size() + (n_len - tokens_list.size())*n_parallel;
 
-    if ((int) tokens_list.size() > max_tokens_list_size) {
-        fprintf(stderr, "%s: error: prompt too long (%d tokens, max %d)\n", __func__, (int) tokens_list.size(), max_tokens_list_size);
+    LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_parallel = %d, n_kv_req = %d\n", __func__, n_len, n_ctx, n_parallel, n_kv_req);
+
+    // make sure the KV cache is big enough to hold all the prompt and the generated tokens
+    if (n_kv_req > n_ctx) {
+        LOG_TEE("%s: error: n_kv_req > n_ctx, the required KV cache size is not big enough\n", __func__);
+        LOG_TEE("%s: either reduce n_parallel or increase n_ctx\n", __func__);
         return 1;
     }
 
-    fprintf(stderr, "\n\n");
+    fprintf(stderr, "\n");
 
     for (auto id : tokens_list) {
         fprintf(stderr, "%s", llama_token_to_piece(ctx, id).c_str());
     }
 
     fflush(stderr);
 
+    // create a llama_batch with size 512
+    // we use this object to submit token data for decoding
+
+    llama_batch batch = llama_batch_init(512, 0);
+
+    // evaluate the initial prompt
+    batch.n_tokens = tokens_list.size();
+
+    for (int32_t i = 0; i < batch.n_tokens; i++) {
+        batch.token[i]  = tokens_list[i];
+        batch.pos[i]    = i;
+        batch.seq_id[i] = 0;
+        batch.logits[i] = false;
+    }
+
+    // llama_decode will output logits only for the last token of the prompt
+    batch.logits[batch.n_tokens - 1] = true;
+
+    if (llama_decode(ctx, batch, params.n_threads) != 0) {
+        LOG_TEE("%s: llama_decode() failed\n", __func__);
+        return 1;
+    }
+
+    // assign the system KV cache to all parallel sequences
+    for (int32_t i = 1; i < n_parallel; ++i) {
+        llama_kv_cache_seq_cp(ctx, 0, i, 0, batch.n_tokens);
+    }
+
+    if (n_parallel > 1) {
+        LOG_TEE("\n\n%s: generating %d sequences ...\n", __func__, n_parallel);
+    }
+
     // main loop
 
-    // The LLM keeps a contextual cache memory of previous token evaluation.
-    // Usually, once this cache is full, it is required to recompute a compressed context based on previous
-    // tokens (see "infinite text generation via context swapping" in the main example), but in this minimalist
-    // example, we will just stop the loop once this cache is full or once an end of stream is detected.
+    // we will store the parallel decoded sequences in this vector
+    std::vector<std::string> streams(n_parallel);
 
-    const int n_gen = std::min(32, max_context_size);
+    // remember the batch index of the last token for each parallel sequence
+    // we will use this to know which logits to sample from
+    std::vector<int32_t> i_batch(n_parallel, batch.n_tokens - 1);
 
-    int n_cur = 0;
+    int n_cur    = batch.n_tokens;
+    int n_decode = 0;
 
-    while (n_cur < n_gen) {
-        // evaluate the transformer
+    const auto t_main_start = ggml_time_us();
 
-        if (llama_decode(ctx, llama_batch_get_one(tokens_list.data(), int(tokens_list.size()), n_cur, 0), params.n_threads)) {
-            fprintf(stderr, "%s : failed to eval\n", __func__);
+    while (n_cur <= n_len) {
+        // evaluate the current batch with the transformer model
+        if (llama_decode(ctx, batch, params.n_threads)) {
+            fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
             return 1;
         }
 
-        n_cur += tokens_list.size();
-        tokens_list.clear();
+        // prepare the next batch
+        batch.n_tokens = 0;
 
-        // sample the next token
+        // sample the next token for each parallel sequence / stream
+        for (int32_t i = 0; i < n_parallel; ++i) {
+            if (i_batch[i] < 0) {
+                // the stream has already finished
+                continue;
+            }
 
-        llama_token new_token_id = 0;
+            auto n_vocab = llama_n_vocab(ctx);
+            auto logits  = llama_get_logits(ctx) + i_batch[i] * n_vocab;
 
-        auto logits = llama_get_logits(ctx);
-        auto n_vocab = llama_n_vocab(ctx);
+            std::vector<llama_token_data> candidates;
+            candidates.reserve(n_vocab);
 
-        std::vector<llama_token_data> candidates;
-        candidates.reserve(n_vocab);
+            for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+                candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
+            }
 
-        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-            candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
-        }
+            llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
 
-        llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+            const int   top_k = 40;
+            const float top_p = 0.9f;
+            const float temp  = 0.4f;
 
-        new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
+            llama_sample_top_k(ctx, &candidates_p, top_k, 1);
+            llama_sample_top_p(ctx, &candidates_p, top_p, 1);
+            llama_sample_temp (ctx, &candidates_p, temp);
 
-        // is it an end of stream ?
-        if (new_token_id == llama_token_eos(ctx)) {
-            fprintf(stderr, " [end of text]\n");
+            const llama_token new_token_id = llama_sample_token(ctx, &candidates_p);
+
+            // const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
+
+            // is it an end of stream ?
+            // mark this stream as finished
+            if (new_token_id == llama_token_eos(ctx) || n_cur == n_len) {
+                i_batch[i] = -1;
+                LOG_TEE("\n");
+                if (n_parallel > 1) {
+                    LOG_TEE("%s: stream %d finished", __func__, i);
+                }
+
+                continue;
+            }
+
+            if (n_parallel == 1) {
+                // print the new token :
+                LOG_TEE("%s", llama_token_to_piece(ctx, new_token_id).c_str());
+                fflush(stdout);
+            }
+
+            streams[i] += llama_token_to_piece(ctx, new_token_id);
+
+            // push this new token for next evaluation
+            batch.token [batch.n_tokens] = new_token_id;
+            batch.pos   [batch.n_tokens] = n_cur;
+            batch.seq_id[batch.n_tokens] = i;
+            batch.logits[batch.n_tokens] = true;
+
+            i_batch[i] = batch.n_tokens;
+
+            batch.n_tokens += 1;
+
+            n_decode += 1;
+        }
+
+        if (batch.n_tokens == 0) {
+            // all streams are finished
             break;
         }
 
-        // print the new token :
-        printf("%s", llama_token_to_piece(ctx, new_token_id).c_str());
-        fflush(stdout);
+        n_cur += 1;
+    }
+
+    LOG_TEE("\n");
 
-        // push this new token for next evaluation
-        tokens_list.push_back(new_token_id);
+    if (n_parallel > 1) {
+        LOG_TEE("\n");
+
+        for (int32_t i = 0; i < n_parallel; ++i) {
+            LOG_TEE("sequence %d:\n\n%s%s\n\n", i, params.prompt.c_str(), streams[i].c_str());
+        }
     }
 
+    const auto t_main_end = ggml_time_us();
+
+    LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
+            __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
+
+    llama_print_timings(ctx);
+
+    fprintf(stderr, "\n");
+
     llama_free(ctx);
     llama_free_model(model);
 
     llama_backend_free();
 
-    fprintf(stderr, "\n\n");
-
     return 0;
 }
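
The heart of this change is the switch from llama_batch_get_one() to an explicitly filled llama_batch. The sketch below (not part of the commit) pulls that submission pattern out into a standalone helper, assuming the API exactly as it appears in the diff: the two-argument llama_batch_init, the per-token token/pos/seq_id/logits fields, and llama_decode(ctx, batch, n_threads). The helper name decode_prompt is illustrative, and llama_batch_free is assumed to be the matching cleanup call even though the diff does not show it.

    // minimal sketch, not part of the commit: decode a tokenized prompt as one
    // llama_batch, requesting logits only for the last position of sequence 0
    #include "llama.h"

    #include <vector>

    // returns the batch index whose logits should be sampled from, or -1 on failure
    static int32_t decode_prompt(llama_context * ctx, const std::vector<llama_token> & tokens, int n_threads) {
        // a batch sized for the whole prompt; the second argument (0) means plain
        // token ids, no embeddings - mirroring llama_batch_init(512, 0) above
        llama_batch batch = llama_batch_init(tokens.size(), 0);

        batch.n_tokens = tokens.size();

        for (int32_t i = 0; i < batch.n_tokens; i++) {
            batch.token[i]  = tokens[i]; // token id
            batch.pos[i]    = i;         // position within the sequence
            batch.seq_id[i] = 0;         // every prompt token belongs to sequence 0
            batch.logits[i] = false;     // intermediate positions need no logits
        }

        // sampling starts after the last prompt token, so only it needs logits
        batch.logits[batch.n_tokens - 1] = true;

        const int32_t i_logits = batch.n_tokens - 1;

        const int ret = llama_decode(ctx, batch, n_threads);

        llama_batch_free(batch); // assumed counterpart of llama_batch_init

        return ret == 0 ? i_logits : -1;
    }

The n_kv_req check added above follows directly from the formula in the diff: the prompt cells are counted once because llama_kv_cache_seq_cp makes them available to every sequence, while the generated part is counted once per sequence. For a hypothetical 8-token prompt with n_len = 32 and n_parallel = 4, that gives n_kv_req = 8 + (32 - 8)*4 = 104 cells, comfortably below the n_ctx = 2048 set earlier in the diff.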