@@ -230,8 +230,8 @@ int main(int argc, char ** argv) {
230
230
fprintf (stderr, " Input prefix: '%s'\n " , params.input_prefix .c_str ());
231
231
}
232
232
}
233
- fprintf (stderr, " sampling: repeat_last_n = %d, repeat_penalty = %f, alpha_presence = %f, alpha_frequency = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f\n " ,
234
- params.repeat_last_n , params.repeat_penalty , params.alpha_presence , params.alpha_frequency , params.top_k , params.tfs_z , params.top_p , params.typical_p , params.temp );
233
+ fprintf (stderr, " sampling: repeat_last_n = %d, repeat_penalty = %f, alpha_presence = %f, alpha_frequency = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_eta = %f, mirostat_tau = %f \n " ,
234
+ params.repeat_last_n , params.repeat_penalty , params.alpha_presence , params.alpha_frequency , params.top_k , params.tfs_z , params.top_p , params.typical_p , params.temp , params. mirostat , params. mirostat_eta , params. mirostat_tau );
235
235
fprintf (stderr, " generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n " , n_ctx, params.n_batch , params.n_predict , params.n_keep );
236
236
fprintf (stderr, " \n\n " );
237
237
@@ -313,6 +313,9 @@ int main(int argc, char ** argv) {
313
313
const float repeat_penalty = params.repeat_penalty ;
314
314
const float alpha_presence = params.alpha_presence ;
315
315
const float alpha_frequency = params.alpha_frequency ;
316
+ const int mirostat = params.mirostat ;
317
+ const float mirostat_tau = params.mirostat_tau ;
318
+ const float mirostat_eta = params.mirostat_eta ;
316
319
317
320
llama_token id = 0 ;
318
321
@@ -326,47 +329,45 @@ int main(int argc, char ** argv) {
326
329
327
330
std::vector<llama_token_data> candidates;
328
331
candidates.reserve (n_vocab);
329
- for (size_t i = 0 ; i < n_vocab; i++) {
332
+ for (size_t i = 0 ; i < ( size_t ) n_vocab; i++) {
330
333
candidates.emplace_back (i, logits[i], 0 .0f );
331
334
}
332
335
333
- llama_token_data_array candidates_p = { candidates.data (), candidates.size () };
336
+ llama_token_data_array candidates_p = { candidates.data (), candidates.size (), false };
334
337
335
338
// Apply penalties
336
339
auto last_n_repeat = std::min (std::min ((int )last_n_tokens.size (), repeat_last_n), n_ctx);
337
- llama_sample_repetition_penalty (&candidates_p,
340
+ llama_sample_repetition_penalty (ctx, &candidates_p,
338
341
last_n_tokens.data () + last_n_tokens.size () - last_n_repeat,
339
342
last_n_repeat, repeat_penalty);
340
- llama_sample_frequency_and_presence_penalties (&candidates_p,
343
+ llama_sample_frequency_and_presence_penalties (ctx, &candidates_p,
341
344
last_n_tokens.data () + last_n_tokens.size () - last_n_repeat,
342
345
last_n_repeat, alpha_frequency, alpha_presence);
343
346
344
347
345
- #if 1
346
348
if (temp <= 0 ) {
347
349
// Greedy sampling
348
350
id = llama_sample_token_greedy (ctx, &candidates_p);
349
351
} else {
350
- // Temperature sampling
351
- llama_sample_top_k (&candidates_p, top_k);
352
- llama_sample_tail_free (&candidates_p, tfs_z);
353
- llama_sample_typical (&candidates_p, typical_p);
354
- llama_sample_top_p (&candidates_p, top_p);
355
-
356
- llama_sample_temperature (&candidates_p, temp);
357
- // printf("`%d`", candidates_p.size);
358
- id = llama_sample_token (ctx, &candidates_p);
352
+ if (mirostat == 1 ) {
353
+ static float mirostat_mu = 2 .0f * mirostat_tau;
354
+ static int mirostat_k = 40 ;
355
+ const int mirostat_m = 100 ;
356
+ id = llama_sample_token_mirostat (ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, float (n_vocab), &mirostat_k, &mirostat_mu);
357
+ } else if (mirostat == 2 ) {
358
+ static float mirostat_mu = 2 .0f * mirostat_tau;
359
+ id = llama_sample_token_mirostat_v2 (ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
360
+ } else {
361
+ // Temperature sampling
362
+ llama_sample_top_k (ctx, &candidates_p, top_k);
363
+ llama_sample_tail_free (ctx, &candidates_p, tfs_z);
364
+ llama_sample_typical (ctx, &candidates_p, typical_p);
365
+ llama_sample_top_p (ctx, &candidates_p, top_p);
366
+ llama_sample_temperature (ctx, &candidates_p, temp);
367
+ id = llama_sample_token (ctx, &candidates_p);
368
+ }
359
369
}
360
- #else
361
- const float tau = 5.0f;
362
- static float mu = 2.0f * tau;
363
- static int k = 40;
364
- const float eta = 0.1f;
365
- const int m = 100;
366
- const float N = n_vocab;
367
- id = llama_sample_mirostat(ctx, &candidates_p, tau, eta, m, N, &k, &mu);
368
- // id = llama_sample_mirostat_v2(ctx, &candidates_p, tau, eta, &mu);
369
- #endif
370
+ // printf("`%d`", candidates_p.size);
370
371
371
372
last_n_tokens.erase (last_n_tokens.begin ());
372
373
last_n_tokens.push_back (id);
0 commit comments