
Commit fdc53e2

common : add llama_sample_token helper function
1 parent 260b4a5 commit fdc53e2

4 files changed: +189, -208 lines

common/common.cpp

Lines changed: 124 additions & 0 deletions

@@ -830,6 +830,130 @@ std::string llama_detokenize_bpe(llama_context * ctx, const std::vector<llama_token
     return result;
 }
 
+//
+// Sampling utils
+//
+
+llama_token llama_sample_token(
+           struct llama_context * ctx,
+           struct llama_context * ctx_guidance,
+           struct llama_grammar * grammar,
+    const struct gpt_params & params,
+    const std::vector<llama_token> & last_tokens,
+          std::vector<llama_token_data> & candidates,
+    int idx) {
+    const int n_ctx   = llama_n_ctx(ctx);
+    const int n_vocab = llama_n_vocab(ctx);
+
+    const float   temp            = params.temp;
+    const int32_t top_k           = params.top_k <= 0 ? n_vocab : params.top_k;
+    const float   top_p           = params.top_p;
+    const float   tfs_z           = params.tfs_z;
+    const float   typical_p       = params.typical_p;
+    const int32_t repeat_last_n   = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
+    const float   repeat_penalty  = params.repeat_penalty;
+    const float   alpha_presence  = params.presence_penalty;
+    const float   alpha_frequency = params.frequency_penalty;
+    const int     mirostat        = params.mirostat;
+    const float   mirostat_tau    = params.mirostat_tau;
+    const float   mirostat_eta    = params.mirostat_eta;
+    const bool    penalize_nl     = params.penalize_nl;
+
+    llama_token id = 0;
+
+    float * logits = llama_get_logits(ctx) + idx * n_vocab;
+
+    // Apply params.logit_bias map
+    for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
+        logits[it->first] += it->second;
+    }
+
+    candidates.clear();
+    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+        candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+    }
+
+    llama_token_data_array cur_p = { candidates.data(), candidates.size(), false };
+
+    if (ctx_guidance) {
+        llama_sample_classifier_free_guidance(ctx, &cur_p, ctx_guidance, params.cfg_scale);
+    }
+
+    // apply penalties
+    if (!last_tokens.empty()) {
+        const float nl_logit = logits[llama_token_nl(ctx)];
+        const int last_n_repeat = std::min(std::min((int)last_tokens.size(), repeat_last_n), n_ctx);
+
+        llama_sample_repetition_penalty(ctx, &cur_p,
+                last_tokens.data() + last_tokens.size() - last_n_repeat,
+                last_n_repeat, repeat_penalty);
+        llama_sample_frequency_and_presence_penalties(ctx, &cur_p,
+                last_tokens.data() + last_tokens.size() - last_n_repeat,
+                last_n_repeat, alpha_frequency, alpha_presence);
+
+        if (!penalize_nl) {
+            for (size_t idx = 0; idx < cur_p.size; idx++) {
+                if (cur_p.data[idx].id == llama_token_nl(ctx)) {
+                    cur_p.data[idx].logit = nl_logit;
+                    break;
+                }
+            }
+        }
+    }
+
+    if (grammar != NULL) {
+        llama_sample_grammar(ctx, &cur_p, grammar);
+    }
+
+    if (temp <= 0) {
+        // Greedy sampling
+        id = llama_sample_token_greedy(ctx, &cur_p);
+    } else {
+        if (mirostat == 1) {
+            static float mirostat_mu = 2.0f * mirostat_tau;
+            const int mirostat_m = 100;
+            llama_sample_temperature(ctx, &cur_p, temp);
+            id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
+        } else if (mirostat == 2) {
+            static float mirostat_mu = 2.0f * mirostat_tau;
+            llama_sample_temperature(ctx, &cur_p, temp);
+            id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &mirostat_mu);
+        } else {
+            // Temperature sampling
+            llama_sample_top_k      (ctx, &cur_p, top_k, 1);
+            llama_sample_tail_free  (ctx, &cur_p, tfs_z, 1);
+            llama_sample_typical    (ctx, &cur_p, typical_p, 1);
+            llama_sample_top_p      (ctx, &cur_p, top_p, 1);
+            llama_sample_temperature(ctx, &cur_p, temp);
+
+            {
+                const int n_top = 10;
+                LOG("top %d candidates:\n", n_top);
+
+                for (int i = 0; i < n_top; i++) {
+                    const llama_token id = cur_p.data[i].id;
+                    LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx, id).c_str(), cur_p.data[i].p);
+                }
+            }
+
+            id = llama_sample_token(ctx, &cur_p);
+
+            LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx, id).c_str());
+        }
+    }
+    // printf("`%d`", candidates_p.size);
+
+    if (grammar != NULL) {
+        llama_grammar_accept_token(ctx, grammar, id);
+    }
+
+    return id;
+}
+
+//
+// YAML utils
+//
+
 // returns true if successful, false otherwise
 bool create_directory_with_parents(const std::string & path) {
 #ifdef _WIN32
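For orientation (not part of the diff), a minimal sketch of how a caller might drive the new helper, assuming a context whose prompt has already been evaluated and a params struct from the usual argument parsing; n_predict and n_past are hypothetical caller-side variables, and llama_eval is assumed to be the evaluation entry point of this era's API:

// Sketch only: driving llama_sample_token from a generation loop.
// Assumes ctx holds a loaded model with the prompt already evaluated.
std::vector<llama_token>      last_tokens(llama_n_ctx(ctx), 0);
std::vector<llama_token_data> candidates;
candidates.reserve(llama_n_vocab(ctx));

for (int i = 0; i < n_predict; i++) {
    // no guidance context, no grammar, default idx = 0
    const llama_token id = llama_sample_token(ctx, NULL, NULL, params, last_tokens, candidates);

    last_tokens.erase(last_tokens.begin());
    last_tokens.push_back(id);

    if (id == llama_token_eos(ctx)) {
        break;
    }

    // feed the sampled token back so the next iteration has fresh logits
    if (llama_eval(ctx, &id, 1, n_past, params.n_threads) != 0) {
        break;
    }
    n_past += 1;
}

Note how candidates is allocated once and reused: the helper clears and refills it on every call, which is why it is passed in by the caller rather than created internally.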

common/common.h

Lines changed: 34 additions & 0 deletions

@@ -157,6 +157,40 @@ std::string llama_detokenize_bpe(
                          llama_context * ctx,
         const std::vector<llama_token> & tokens);
 
+//
+// Sampling utils
+//
+
+// this is a common sampling function used across the examples for convenience
+// it can serve as a starting point for implementing your own sampling function
+//
+// required:
+//  - ctx:    context to use for sampling
+//  - params: sampling parameters
+//
+// optional:
+//  - ctx_guidance: context to use for classifier-free guidance, ignore if NULL
+//  - grammar:      grammar to use for sampling, ignore if NULL
+//  - last_tokens:  needed for repetition penalty, ignore if empty
+//  - idx:          sample from llama_get_logits(ctx) + idx * n_vocab
+//
+// returns:
+//  - token:      sampled token
+//  - candidates: vector of candidate tokens
+//
+llama_token llama_sample_token(
+           struct llama_context * ctx,
+           struct llama_context * ctx_guidance,
+           struct llama_grammar * grammar,
+    const struct gpt_params & params,
+    const std::vector<llama_token> & last_tokens,
+          std::vector<llama_token_data> & candidates,
+    int idx = 0);
+
+//
+// YAML utils
+//
+
 bool create_directory_with_parents(const std::string & path);
 void dump_vector_float_yaml(FILE * stream, const char * prop_name, const std::vector<float> & data);
 void dump_vector_int_yaml(FILE * stream, const char * prop_name, const std::vector<int> & data);
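Since the header comment pitches the helper as a starting point for custom samplers, here is a hypothetical stripped-down variant (my_sample_greedy is illustrative, not part of this commit) that keeps only the logit-bias step and a greedy pick, built from the same public calls the diff uses:

// Hypothetical minimal sampler for illustration; not from this commit.
// Applies params.logit_bias, then picks the highest-logit token.
static llama_token my_sample_greedy(struct llama_context * ctx, const struct gpt_params & params) {
    const int n_vocab = llama_n_vocab(ctx);
    float * logits = llama_get_logits(ctx);

    // same logit-bias pass as llama_sample_token
    for (const auto & it : params.logit_bias) {
        logits[it.first] += it.second;
    }

    std::vector<llama_token_data> candidates;
    candidates.reserve(n_vocab);
    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
        candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
    }

    llama_token_data_array cur_p = { candidates.data(), candidates.size(), false };
    return llama_sample_token_greedy(ctx, &cur_p);
}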

examples/main/main.cpp

Lines changed: 20 additions & 114 deletions

@@ -425,8 +425,9 @@ int main(int argc, char ** argv) {
     LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
     LOG_TEE("\n\n");
 
+    struct llama_grammar * grammar = NULL;
     grammar_parser::parse_state parsed_grammar;
-    llama_grammar * grammar = NULL;
+
     if (!params.grammar.empty()) {
         parsed_grammar = grammar_parser::parse(params.grammar.c_str());
         // will be empty (default) if there are parse errors
@@ -450,8 +451,8 @@
     }
 
     // TODO: replace with ring-buffer
-    std::vector<llama_token> last_n_tokens(n_ctx);
-    std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
+    std::vector<llama_token> last_tokens(n_ctx);
+    std::fill(last_tokens.begin(), last_tokens.end(), 0);
 
     if (params.interactive) {
         const char *control_message;
@@ -500,6 +501,11 @@
         llama_reset_timings(ctx);
     }
 
+    const int n_vocab = llama_n_vocab(ctx);
+
+    std::vector<llama_token_data> candidates;
+    candidates.reserve(n_vocab);
+
     while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
         // predict
         if (embd.size() > 0) {
@@ -537,8 +543,8 @@
 
                 LOG("after swap: n_past = %d, n_past_guidance = %d\n", n_past, n_past_guidance);
 
-                // insert n_left/2 tokens at the start of embd from last_n_tokens
-                embd.insert(embd.begin(), last_n_tokens.begin() + n_ctx - n_left/2 - embd.size(), last_n_tokens.end() - embd.size());
+                // insert n_left/2 tokens at the start of embd from last_tokens
+                embd.insert(embd.begin(), last_tokens.begin() + n_ctx - n_left/2 - embd.size(), last_tokens.end() - embd.size());
 
                 LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd));
 
@@ -637,20 +643,6 @@
         embd_guidance.clear();
 
         if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
-            const float   temp            = params.temp;
-            const int32_t top_k           = params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k;
-            const float   top_p           = params.top_p;
-            const float   tfs_z           = params.tfs_z;
-            const float   typical_p       = params.typical_p;
-            const int32_t repeat_last_n   = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
-            const float   repeat_penalty  = params.repeat_penalty;
-            const float   alpha_presence  = params.presence_penalty;
-            const float   alpha_frequency = params.frequency_penalty;
-            const int     mirostat        = params.mirostat;
-            const float   mirostat_tau    = params.mirostat_tau;
-            const float   mirostat_eta    = params.mirostat_eta;
-            const bool    penalize_nl     = params.penalize_nl;
-
             // optionally save the session on first sample (for faster prompt loading next time)
             if (!path_session.empty() && need_to_save_session && !params.prompt_cache_ro) {
                 need_to_save_session = false;
@@ -659,98 +651,12 @@
                 LOG("saved session to %s\n", path_session.c_str());
             }
 
-            llama_token id = 0;
-
-            {
-                auto logits  = llama_get_logits(ctx);
-                auto n_vocab = llama_n_vocab(ctx);
-
-                // Apply params.logit_bias map
-                for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
-                    logits[it->first] += it->second;
-                }
-
-                std::vector<llama_token_data> candidates;
-                candidates.reserve(n_vocab);
-                for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-                    candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
-                }
-
-                llama_token_data_array cur_p = { candidates.data(), candidates.size(), false };
-
-                if (ctx_guidance) {
-                    llama_sample_classifier_free_guidance(ctx, &cur_p, ctx_guidance, params.cfg_scale);
-                }
-
-                // Apply penalties
-                float nl_logit = logits[llama_token_nl(ctx)];
-                auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx);
-                llama_sample_repetition_penalty(ctx, &cur_p,
-                    last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
-                    last_n_repeat, repeat_penalty);
-                llama_sample_frequency_and_presence_penalties(ctx, &cur_p,
-                    last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
-                    last_n_repeat, alpha_frequency, alpha_presence);
-                if (!penalize_nl) {
-                    for (size_t idx = 0; idx < cur_p.size; idx++) {
-                        if (cur_p.data[idx].id == llama_token_nl(ctx)) {
-                            cur_p.data[idx].logit = nl_logit;
-                            break;
-                        }
-                    }
-                }
-
-                if (grammar != NULL) {
-                    llama_sample_grammar(ctx, &cur_p, grammar);
-                }
-
-                if (temp <= 0) {
-                    // Greedy sampling
-                    id = llama_sample_token_greedy(ctx, &cur_p);
-                } else {
-                    if (mirostat == 1) {
-                        static float mirostat_mu = 2.0f * mirostat_tau;
-                        const int mirostat_m = 100;
-                        llama_sample_temperature(ctx, &cur_p, temp);
-                        id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
-                    } else if (mirostat == 2) {
-                        static float mirostat_mu = 2.0f * mirostat_tau;
-                        llama_sample_temperature(ctx, &cur_p, temp);
-                        id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &mirostat_mu);
-                    } else {
-                        // Temperature sampling
-                        llama_sample_top_k      (ctx, &cur_p, top_k, 1);
-                        llama_sample_tail_free  (ctx, &cur_p, tfs_z, 1);
-                        llama_sample_typical    (ctx, &cur_p, typical_p, 1);
-                        llama_sample_top_p      (ctx, &cur_p, top_p, 1);
-                        llama_sample_temperature(ctx, &cur_p, temp);
-
-                        {
-                            const int n_top = 10;
-                            LOG("top %d candidates:\n", n_top);
-
-                            for (int i = 0; i < n_top; i++) {
-                                const llama_token id = cur_p.data[i].id;
-                                LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx, id).c_str(), cur_p.data[i].p);
-                            }
-                        }
-
-                        id = llama_sample_token(ctx, &cur_p);
+            const llama_token id = llama_sample_token(ctx, ctx_guidance, grammar, params, last_tokens, candidates);
 
-                        LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx, id).c_str());
-                    }
-                }
-                // printf("`%d`", candidates_p.size);
+            last_tokens.erase(last_tokens.begin());
+            last_tokens.push_back(id);
 
-                if (grammar != NULL) {
-                    llama_grammar_accept_token(ctx, grammar, id);
-                }
-
-                last_n_tokens.erase(last_n_tokens.begin());
-                last_n_tokens.push_back(id);
-
-                LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, last_n_tokens));
-            }
+            LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, last_tokens));
 
             embd.push_back(id);
 
@@ -766,8 +672,8 @@
             LOG("embd_inp.size(): %d, n_consumed: %d\n", (int) embd_inp.size(), n_consumed);
             while ((int) embd_inp.size() > n_consumed) {
                 embd.push_back(embd_inp[n_consumed]);
-                last_n_tokens.erase(last_n_tokens.begin());
-                last_n_tokens.push_back(embd_inp[n_consumed]);
+                last_tokens.erase(last_tokens.begin());
+                last_tokens.push_back(embd_inp[n_consumed]);
                 ++n_consumed;
                 if ((int) embd.size() >= params.n_batch) {
                     break;
@@ -800,7 +706,7 @@
             // check for reverse prompt
             if (params.antiprompt.size()) {
                 std::string last_output;
-                for (auto id : last_n_tokens) {
+                for (auto id : last_tokens) {
                     last_output += llama_token_to_piece(ctx, id);
                 }
 
@@ -831,7 +737,7 @@
             }
 
             // deal with end of text token in interactive mode
-            if (last_n_tokens.back() == llama_token_eos(ctx)) {
+            if (last_tokens.back() == llama_token_eos(ctx)) {
                 LOG("found EOS token\n");
 
                 if (params.interactive) {
@@ -933,7 +839,7 @@
     if (grammar != NULL) {
         llama_grammar_free(grammar);
 
-        std::vector<const llama_grammar_element *> grammar_rules( parsed_grammar.c_rules());
+        std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
         grammar = llama_grammar_init(
             grammar_rules.data(), grammar_rules.size(),
             parsed_grammar.symbol_ids.at("root"));