@@ -276,8 +276,8 @@ int main(int argc, char ** argv) {
276
276
fprintf (stderr, " Input prefix: '%s'\n " , params.input_prefix .c_str ());
277
277
}
278
278
}
279
- fprintf (stderr, " sampling: repeat_last_n = %d, repeat_penalty = %f, alpha_presence = %f, alpha_frequency = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f\n " ,
280
- params.repeat_last_n , params.repeat_penalty , params.alpha_presence , params.alpha_frequency , params.top_k , params.tfs_z , params.top_p , params.typical_p , params.temp );
279
+ fprintf (stderr, " sampling: repeat_last_n = %d, repeat_penalty = %f, alpha_presence = %f, alpha_frequency = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_eta = %f, mirostat_tau = %f \n " ,
280
+ params.repeat_last_n , params.repeat_penalty , params.alpha_presence , params.alpha_frequency , params.top_k , params.tfs_z , params.top_p , params.typical_p , params.temp , params. mirostat , params. mirostat_eta , params. mirostat_tau );
281
281
fprintf (stderr, " generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n " , n_ctx, params.n_batch , params.n_predict , params.n_keep );
282
282
fprintf (stderr, " \n\n " );
283
283
@@ -396,6 +396,9 @@ int main(int argc, char ** argv) {
396
396
const float repeat_penalty = params.repeat_penalty ;
397
397
const float alpha_presence = params.alpha_presence ;
398
398
const float alpha_frequency = params.alpha_frequency ;
399
+ const int mirostat = params.mirostat ;
400
+ const float mirostat_tau = params.mirostat_tau ;
401
+ const float mirostat_eta = params.mirostat_eta ;
399
402
400
403
// optionally save the session on first sample (for faster prompt loading next time)
401
404
if (!path_session.empty () && need_to_save_session) {
@@ -415,47 +418,45 @@ int main(int argc, char ** argv) {
415
418
416
419
std::vector<llama_token_data> candidates;
417
420
candidates.reserve (n_vocab);
418
- for (size_t i = 0 ; i < n_vocab; i++) {
421
+ for (size_t i = 0 ; i < ( size_t ) n_vocab; i++) {
419
422
candidates.emplace_back (i, logits[i], 0 .0f );
420
423
}
421
424
422
- llama_token_data_array candidates_p = { candidates.data (), candidates.size () };
425
+ llama_token_data_array candidates_p = { candidates.data (), candidates.size (), false };
423
426
424
427
// Apply penalties
425
428
auto last_n_repeat = std::min (std::min ((int )last_n_tokens.size (), repeat_last_n), n_ctx);
426
- llama_sample_repetition_penalty (&candidates_p,
429
+ llama_sample_repetition_penalty (ctx, &candidates_p,
427
430
last_n_tokens.data () + last_n_tokens.size () - last_n_repeat,
428
431
last_n_repeat, repeat_penalty);
429
- llama_sample_frequency_and_presence_penalties (&candidates_p,
432
+ llama_sample_frequency_and_presence_penalties (ctx, &candidates_p,
430
433
last_n_tokens.data () + last_n_tokens.size () - last_n_repeat,
431
434
last_n_repeat, alpha_frequency, alpha_presence);
432
435
433
436
434
- #if 1
435
437
if (temp <= 0 ) {
436
438
// Greedy sampling
437
439
id = llama_sample_token_greedy (ctx, &candidates_p);
438
440
} else {
439
- // Temperature sampling
440
- llama_sample_top_k (&candidates_p, top_k);
441
- llama_sample_tail_free (&candidates_p, tfs_z);
442
- llama_sample_typical (&candidates_p, typical_p);
443
- llama_sample_top_p (&candidates_p, top_p);
444
-
445
- llama_sample_temperature (&candidates_p, temp);
446
- // printf("`%d`", candidates_p.size);
447
- id = llama_sample_token (ctx, &candidates_p);
441
+ if (mirostat == 1 ) {
442
+ static float mirostat_mu = 2 .0f * mirostat_tau;
443
+ static int mirostat_k = 40 ;
444
+ const int mirostat_m = 100 ;
445
+ id = llama_sample_token_mirostat (ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, float (n_vocab), &mirostat_k, &mirostat_mu);
446
+ } else if (mirostat == 2 ) {
447
+ static float mirostat_mu = 2 .0f * mirostat_tau;
448
+ id = llama_sample_token_mirostat_v2 (ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
449
+ } else {
450
+ // Temperature sampling
451
+ llama_sample_top_k (ctx, &candidates_p, top_k);
452
+ llama_sample_tail_free (ctx, &candidates_p, tfs_z);
453
+ llama_sample_typical (ctx, &candidates_p, typical_p);
454
+ llama_sample_top_p (ctx, &candidates_p, top_p);
455
+ llama_sample_temperature (ctx, &candidates_p, temp);
456
+ id = llama_sample_token (ctx, &candidates_p);
457
+ }
448
458
}
449
- #else
450
- const float tau = 5.0f;
451
- static float mu = 2.0f * tau;
452
- static int k = 40;
453
- const float eta = 0.1f;
454
- const int m = 100;
455
- const float N = n_vocab;
456
- id = llama_sample_mirostat(ctx, &candidates_p, tau, eta, m, N, &k, &mu);
457
- // id = llama_sample_mirostat_v2(ctx, &candidates_p, tau, eta, &mu);
458
- #endif
459
+ // printf("`%d`", candidates_p.size);
459
460
460
461
last_n_tokens.erase (last_n_tokens.begin ());
461
462
last_n_tokens.push_back (id);
0 commit comments