@@ -68,13 +68,15 @@ int main(int argc, char ** argv) {
 
     LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_parallel = %d, n_kv_req = %d\n", __func__, n_len, n_ctx, n_parallel, n_kv_req);
 
-    // make sure wi
+    // make sure the KV cache is big enough to hold all the prompt and generated tokens
     if (n_kv_req > n_ctx) {
         LOG_TEE("%s: error: n_kv_req > n_ctx, the required KV cache size is not big enough\n", __func__);
         LOG_TEE("%s: either reduce n_parallel or increase n_ctx\n", __func__);
         return 1;
     }
 
+    // print the prompt token-by-token
+
     fprintf(stderr, "\n");
 
     for (auto id : tokens_list) {
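For a sense of the arithmetic behind this check: in this example `n_kv_req` is derived from the prompt length, `n_len`, and `n_parallel` (roughly `n_prompt + (n_len - n_prompt) * n_parallel`, since the prompt is shared across sequences while each stream generates its own continuation; treat the exact formula as read from the surrounding example, not from this diff). With a 32-token prompt, `n_len = 128`, and `n_parallel = 4`, that gives `32 + 96 * 4 = 416`, so `n_ctx` must be at least 416 for the check to pass.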
@@ -107,6 +109,7 @@ int main(int argc, char ** argv) {
     }
 
     // assign the system KV cache to all parallel sequences
+    // this way, the parallel sequences will "reuse" the prompt tokens without having to copy them
     for (int32_t i = 1; i < n_parallel; ++i) {
         llama_kv_cache_seq_cp(ctx, 0, i, 0, batch.n_tokens);
     }
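This fan-out is the core trick of the example: `llama_kv_cache_seq_cp(ctx, seq_src, seq_dst, p0, p1)` copies the cached KV entries of sequence `seq_src` over the position range `[p0, p1)` to sequence `seq_dst` (the parameter names here are descriptive, not necessarily the official ones). A minimal sketch of the pattern, assuming the prompt has already been decoded once into sequence 0:

```cpp
// after the first llama_decode() call, the prompt occupies positions
// [0, batch.n_tokens) of sequence 0 in the KV cache
for (int32_t i = 1; i < n_parallel; ++i) {
    // make sequences 1 .. n_parallel-1 attend to the same cached prompt,
    // so it never has to be re-evaluated per stream
    llama_kv_cache_seq_cp(ctx, 0, i, 0, batch.n_tokens);
}
```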
@@ -120,8 +123,8 @@ int main(int argc, char ** argv) {
     // we will store the parallel decoded sequences in this vector
     std::vector<std::string> streams(n_parallel);
 
-    // remember the batch index of the last tokenn for each parallel sequence
-    // we will use this to know which logits to sample from
+    // remember the batch index of the last token for each parallel sequence
+    // we need this to determine which logits to sample from
     std::vector<int32_t> i_batch(n_parallel, batch.n_tokens - 1);
 
     int n_cur = batch.n_tokens;
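`i_batch[i]` records which row of the most recently decoded batch holds the last token of stream `i`, because the logits produced by `llama_decode` are indexed per batch row. A minimal sketch of how these indices are consumed inside the generation loop (using `llama_get_logits_ith` from the llama.cpp API of this era; the candidate-building and sampling steps are elided):

```cpp
for (int32_t i = 0; i < n_parallel; ++i) {
    if (i_batch[i] < 0) {
        continue; // this stream has already finished, nothing to sample
    }

    // logits for the last token of stream i: one float per vocabulary entry
    const float * logits = llama_get_logits_ith(ctx, i_batch[i]);

    // ... build the candidate list from `logits` and sample new_token_id ...
}
```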
@@ -170,8 +173,7 @@ int main(int argc, char ** argv) {
 
             // const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
 
-            // is it an end of stream ?
-            // mark this stream as finished
+            // is it an end of stream? -> mark the stream as finished
             if (new_token_id == llama_token_eos(ctx) || n_cur == n_len) {
                 i_batch[i] = -1;
                 LOG_TEE("\n");
@@ -182,8 +184,8 @@ int main(int argc, char ** argv) {
                 continue;
             }
 
+            // if there is only one stream, we print immediately to stdout
             if (n_parallel == 1) {
-                // print the new token :
                 LOG_TEE("%s", llama_token_to_piece(ctx, new_token_id).c_str());
                 fflush(stdout);
             }
@@ -203,8 +205,8 @@ int main(int argc, char ** argv) {
             n_decode += 1;
         }
 
+        // all streams are finished
         if (batch.n_tokens == 0) {
-            // all streams are finished
             break;
         }
 
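For context on the moved comment: earlier in the loop (outside this diff) the batch is rebuilt on every iteration with a token pushed only for streams whose `i_batch[i]` is still valid, so `batch.n_tokens == 0` is precisely the "every stream has finished" condition, and hoisting the comment above the `if` makes that intent visible before the check.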
@@ -230,6 +232,8 @@ int main(int argc, char ** argv) {
 
     fprintf(stderr, "\n");
 
+    llama_batch_free(batch);
+
     llama_free(ctx);
     llama_free_model(model);
 
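The newly added `llama_batch_free(batch)` pairs with the `llama_batch_init` call that allocated the batch earlier in the example, plugging a small leak of the batch's token/position/seq-id buffers. A minimal sketch of the pairing (the two-argument `llama_batch_init` signature shown here is an assumption about this vintage of the API; it has since changed):

```cpp
// allocate a reusable batch for up to 512 tokens (0 = token ids, not embeddings)
llama_batch batch = llama_batch_init(512, 0);

// ... fill the batch and run llama_decode() in a loop ...

llama_batch_free(batch);     // release the batch buffers first
llama_free(ctx);             // then the context
llama_free_model(model);     // then the model
```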