llama : infill sampling handle very long tokens #9924

Merged: 2 commits, Oct 17, 2024
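
This PR reworks how the infill sampler merges candidate tokens that share a common prefix so that very long token texts are handled correctly. The public llama_token_is_prefix() helper, which detokenized into fixed 128-byte stack buffers, is removed; the comparison now happens inside the sampler itself, using two reusable std::vector<char> buffers that grow on demand.

Below is a minimal sketch of the resize-and-retry pattern the sampler now relies on, written against the public llama_token_to_piece() API; the piece_of() helper and its callers are illustrative, not part of the PR.

#include <string>
#include <vector>

#include "llama.h"

static std::string piece_of(const struct llama_model * model, llama_token token) {
    std::vector<char> buf(512); // same initial capacity the sampler picks
    int32_t len = llama_token_to_piece(model, token, buf.data(), (int32_t) buf.size(), 0, false);
    if (len < 0) {
        buf.resize(-len); // a negative result is the negated size the piece needs
        len = llama_token_to_piece(model, token, buf.data(), (int32_t) buf.size(), 0, false);
    }
    return std::string(buf.data(), len > 0 ? (size_t) len : 0);
}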
6 changes: 0 additions & 6 deletions include/llama.h
@@ -953,12 +953,6 @@ extern "C" {
                                int32_t   lstrip,
                                   bool   special);
-
-    // check if token0 is contained as a prefix in token1
-    LLAMA_API bool llama_token_is_prefix(
-              const struct llama_model * model,
-                           llama_token   token0,
-                           llama_token   token1);
 
     /// @details Convert the provided tokens into text (inverse of llama_tokenize()).
     /// @param text The char pointer must be large enough to hold the resulting text.
     /// @return Returns the number of chars/bytes on success, no more than text_len_max.
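
Downstream code that called the removed llama_token_is_prefix() can reproduce it with the public llama_token_to_piece() API. Here is a hypothetical stand-in, reusing the piece_of() sketch above; unlike the old helper, it also works for token texts longer than 128 bytes.

#include <cstring>

static bool token_is_prefix(const struct llama_model * model, llama_token token0, llama_token token1) {
    const std::string s0 = piece_of(model, token0);
    const std::string s1 = piece_of(model, token1);
    return !s0.empty() && s0.size() <= s1.size() && std::memcmp(s0.data(), s1.data(), s0.size()) == 0;
}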
48 changes: 35 additions & 13 deletions src/llama-sampling.cpp
Expand Up @@ -1745,6 +1745,9 @@ struct llama_sampler * llama_sampler_init_logit_bias(

 struct llama_sampler_infill {
     const struct llama_vocab * vocab;
+
+    std::vector<char> buf0;
+    std::vector<char> buf1;
 };
 
 static const char * llama_sampler_infill_name(const struct llama_sampler * /*smpl*/) {
@@ -1810,27 +1813,44 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
     size_t n_combined = 0; GGML_UNUSED(n_combined);
 
     // combine tokens with common prefix
-    for (size_t i = 0; i < cur_p->size; ++i) {
-        for (size_t j = 0; j < cur_p->size; ++j) {
-            if (cur_p->data[i].logit == -INFINITY) {
+    for (size_t i0 = 0; i0 < cur_p->size; ++i0) {
+        for (size_t i1 = 0; i1 < cur_p->size; ++i1) {
+            if (cur_p->data[i0].logit == -INFINITY) {
                 break;
             }
 
-            if (i == j || cur_p->data[j].logit == -INFINITY) {
+            if (i0 == i1 || cur_p->data[i1].logit == -INFINITY) {
                 continue;
            }
 
-            if (llama_token_is_prefix_impl(*ctx->vocab, cur_p->data[i].id, cur_p->data[j].id)) {
-                if (cur_p->data[i].p > cur_p->data[j].p) {
-                    cur_p->data[i].p += cur_p->data[j].p;
-                    cur_p->data[j].logit = -INFINITY;
-                    cur_p->data[j].p = 0.0f;
-                } else {
-                    cur_p->data[j].p += cur_p->data[i].p;
-                    cur_p->data[i].logit = -INFINITY;
-                    cur_p->data[i].p = 0.0f;
+            int len0 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
+            if (len0 < 0) {
+                ctx->buf0.resize(-len0); // a negative result is the negated size the piece needs
+                len0 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
+                assert(len0 > 0);
+            }
+
+            int len1 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
+            if (len1 < 0) {
+                ctx->buf1.resize(-len1);
+                len1 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
+                assert(len1 > 0);
+            }
+
+            // token i0 is a prefix of token i1
+            if (len0 > 0 && len0 <= len1 && memcmp(ctx->buf0.data(), ctx->buf1.data(), len0) == 0) {
+                int dst = i0;
+                int src = i1;
+
+                // merge into the token with higher probability
+                if (cur_p->data[i1].p > cur_p->data[i0].p) {
+                    std::swap(dst, src);
                 }
+
+                cur_p->data[dst].p += cur_p->data[src].p;
+                cur_p->data[src].logit = -INFINITY;
+                cur_p->data[src].p = 0.0f;
 
                 n_combined++;
             }
         }
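
For intuition, here is a toy, self-contained illustration (not llama.cpp code) of the merge rule above: a candidate whose text is a prefix of another candidate's text is folded into whichever of the pair is more probable, and the loser is knocked out the same way the sampler does it with a -INFINITY logit.

#include <cstdio>
#include <cstring>
#include <utility>

struct candidate { const char * text; float p; };

int main() {
    candidate cands[] = { {"in", 0.3f}, {"int", 0.5f}, {"x", 0.2f} };
    const size_t n = sizeof(cands)/sizeof(cands[0]);
    for (size_t i0 = 0; i0 < n; ++i0) {
        for (size_t i1 = 0; i1 < n; ++i1) {
            if (i0 == i1 || cands[i0].p == 0.0f || cands[i1].p == 0.0f) {
                continue;
            }
            const size_t len0 = std::strlen(cands[i0].text);
            if (len0 <= std::strlen(cands[i1].text) && std::memcmp(cands[i0].text, cands[i1].text, len0) == 0) {
                size_t dst = i0, src = i1;
                if (cands[i1].p > cands[i0].p) {
                    std::swap(dst, src); // merge into the more probable token
                }
                cands[dst].p += cands[src].p;
                cands[src].p = 0.0f; // stands in for the -INFINITY logit
            }
        }
    }
    for (size_t i = 0; i < n; ++i) {
        std::printf("%-3s %.1f\n", cands[i].text, cands[i].p); // prints: in 0.0, int 0.8, x 0.2
    }
}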
@@ -1936,6 +1956,8 @@ struct llama_sampler * llama_sampler_init_infill_impl(
         /* .iface = */ &llama_sampler_infill_i,
         /* .ctx   = */ new llama_sampler_infill {
             /* .vocab = */ &vocab,
+            /* .buf0  = */ std::vector<char>(512),
+            /* .buf1  = */ std::vector<char>(512),
         },
     };
 }
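
The 512-byte initial size means the buffers only reallocate for unusually long token texts. For context, a hedged sketch of where the reworked sampler sits in an application; llama_sampler_chain_init(), llama_sampler_chain_add() and llama_sampler_init_infill() are existing public APIs, while `model` (a loaded llama_model) and `ctx` (a llama_context) are assumed to come from the surrounding code.

llama_sampler_chain_params sparams = llama_sampler_chain_default_params();
struct llama_sampler * chain = llama_sampler_chain_init(sparams);

llama_sampler_chain_add(chain, llama_sampler_init_infill(model));

// ... sample with llama_sampler_sample(chain, ctx, -1), then:

llama_sampler_free(chain);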
17 changes: 0 additions & 17 deletions src/llama-vocab.cpp
@@ -1858,23 +1858,6 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token
     return 0;
 }
-
-bool llama_token_is_prefix_impl(
-        const struct llama_vocab & vocab,
-                       llama_token token0,
-                       llama_token token1) {
-    char text_buf_0[128];
-    char text_buf_1[128];
-
-    const int32_t len0 = llama_token_to_piece_impl(vocab, token0, text_buf_0, sizeof(text_buf_0) - 1, 0, false);
-    const int32_t len1 = llama_token_to_piece_impl(vocab, token1, text_buf_1, sizeof(text_buf_1) - 1, 0, false);
-
-    if (len0 <= 0 || len1 <= 0) {
-        return false;
-    }
-
-    return len0 <= len1 && memcmp(text_buf_0, text_buf_1, len0) == 0;
-}
 
 int32_t llama_detokenize_impl(
         const struct llama_vocab & vocab,
         const llama_token * tokens,
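
Note that this removed implementation is what the PR title refers to: with fixed 128-byte stack buffers, llama_token_to_piece_impl() returns a negative length for any token text that does not fit, so the len0 <= 0 check silently reported very long tokens as never being prefixes.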
7 changes: 0 additions & 7 deletions src/llama.cpp
@@ -21466,13 +21466,6 @@ int32_t llama_token_to_piece(
     return llama_token_to_piece_impl(model->vocab, token, buf, length, lstrip, special);
 }
-
-bool llama_token_is_prefix(
-        const struct llama_model * model,
-                      llama_token token0,
-                      llama_token token1) {
-    return llama_token_is_prefix_impl(model->vocab, token0, token1);
-}
 
 int32_t llama_detokenize(
     const struct llama_model * model,
     const llama_token * tokens,