Skip to content

Add option to ignore tokens with 2+ English characters #8279

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ struct gpt_params {
float yarn_beta_slow = 1.0f; // YaRN high correction dim
int32_t yarn_orig_ctx = 0; // YaRN original context length
float defrag_thold = -1.0f; // KV cache defragmentation threshold

bool ignore_english_tokens = false; // Experimental: try to avoid sampling tokens that contain 2+ English characters

ggml_backend_sched_eval_callback cb_eval = nullptr;
void * cb_eval_user_data = nullptr;

Expand Down
30 changes: 28 additions & 2 deletions common/sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,10 +121,12 @@ std::string llama_sampling_print(const llama_sampling_params & params) {
snprintf(result, sizeof(result),
"\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
"\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, typical_p = %.3f, temp = %.3f\n"
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f\n"
"\tignore_english_tokens = %s",
params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present,
params.top_k, params.tfs_z, params.top_p, params.min_p, params.typical_p, params.temp,
params.mirostat, params.mirostat_eta, params.mirostat_tau);
params.mirostat, params.mirostat_eta, params.mirostat_tau,
params.ignore_english_tokens ? "true" : "false");

return std::string(result);
}
Expand Down Expand Up @@ -423,6 +425,13 @@ static llama_token_data_array llama_sampling_prepare_impl(
llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
}

if (params.ignore_english_tokens) {
for (size_t i = 0; i < cur_p.size; ++i) {
if (is_english_token(ctx_main, cur_p.data[i].id)) {
cur_p.data[i].logit = -INFINITY;
}
}

return cur_p;
}

Expand Down Expand Up @@ -457,3 +466,20 @@ void llama_sampling_accept(
llama_grammar_accept_token(ctx_main, ctx_sampling->grammar, id);
}
}

// Reports whether a token counts as "English" for the ignore_english_tokens
// sampling option.
//
// The token is converted to its text piece with llama_token_to_piece() and the
// piece is scanned one char at a time.
//
// Returns true when the piece contains at least two lowercase ASCII letters
// ('a'..'z') and no '<' or '>' character; returns false otherwise.
// Note: upper-case letters ('A'..'Z') are not counted by this check.
bool is_english_token(const llama_context * ctx, llama_token token) {
const std::string token_str = llama_token_to_piece(ctx, token);
int english_char_count = 0;
bool has_angle_bracket = false;

// Single pass over the piece: count lowercase letters and note any angle bracket.
for (char c : token_str) {
if (c >= 'a' && c <= 'z') {
Copy link
Collaborator

@HanClinto HanClinto Jul 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that grammars are the proper way to address this task, but quick comment to note that if one is going to take this approach, then one should probably also check for upper-case characters as well.

Suggested change
if (c >= 'a' && c <= 'z') {
if ((c >= 'a' && c <= 'z') ||
(c >= 'A' && c <= 'Z')) {

Edit: Just realized that the grammar file that hopto-dot created already takes care of this. Nevermind me! :)

english_char_count++;
}
// NOTE(review): any '<' or '>' exempts the token — presumably so special/control
// tokens written like "<...>" are never filtered; confirm against the vocabulary.
if (c == '<' || c == '>') {
has_angle_bracket = true;
}
}

// "English" means two or more lowercase letters and no angle bracket anywhere in the piece.
return english_char_count >= 2 && !has_angle_bracket;
}
1 change: 1 addition & 0 deletions common/sampling.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ typedef struct llama_sampling_params {
float mirostat_eta = 0.10f; // learning rate
bool penalize_nl = false; // consider newlines as a repeatable token
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context
bool ignore_english_tokens = false; // Ignore tokens with 2+ English characters (except those with angle brackets)

std::vector<llama_sampler_type> samplers_sequence = {
llama_sampler_type::TOP_K,
Expand Down