lookup: use hashmaps, select most frequent tokens, abort draft early if no good candidates #5462

170 changes: 142 additions & 28 deletions examples/lookup/lookup.cpp
@@ -7,6 +7,22 @@
#include <cstdio>
#include <string>
#include <vector>
#include <unordered_map>

// Data structures to map n-grams to empirical token probabilities:
typedef std::unordered_map<llama_token, int> token_hashmap; // token -> number of times token has been seen
typedef std::unordered_map<uint64_t, token_hashmap> all_token_hashmap; // n-gram -> empirical distribution of following tokens
// n-grams are encoded as 64-bit integers, with each of the four 16-bit sections representing a token id.
// This way no custom hashing function for the n-grams is needed.
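// Illustrative example (hypothetical token ids): with this layout the 4-gram {1, 2, 3, 4}
// is encoded as 0x0001000200030004.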

// Min/max n-gram size to search for in prompt:
constexpr int ngram_min = 1;
constexpr int ngram_max = 4;
static_assert(ngram_max <= sizeof(uint64_t)/2, "A 64-bit integer can only hold information for four 16-bit tokens.");

// If the sample size or the share of the most frequent token in context is below these thresholds, the draft is aborted early:
constexpr float draft_min_sample_size[ngram_max] = { 2, 2, 1, 1};
constexpr float draft_min_percent[ngram_max] = {66, 50, 50, 50};
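// Illustrative reading of the thresholds above (indexed by ngram_size - 1):
// a 1-gram candidate needs at least 2 observed continuations, of which the most
// frequent must make up at least 66%, while 3-grams and 4-grams are drafted from
// a single observation with a 50% majority.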

int main(int argc, char ** argv){
gpt_params params;
@@ -16,9 +16,6 @@ int main(int argc, char ** argv){
}

// max/min n-grams size to search for in prompt
const int ngram_max = 4;
const int ngram_min = 1;

// length of the candidate / draft sequence, if match is found
const int n_draft = params.n_draft;

@@ -38,6 +51,7 @@ int main(int argc, char ** argv){

// load the model
std::tie(model, ctx) = llama_init_from_gpt_params(params);
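// The n-gram packing scheme above requires token ids to fit into 16 bits: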
GGML_ASSERT(llama_n_vocab(model) < (1 << 16));

// tokenize the prompt
const bool add_bos = llama_should_add_bos_token(model);
@@ -46,6 +60,55 @@ int main(int argc, char ** argv){
std::vector<llama_token> inp;
inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);

auto update_hashmaps = [](all_token_hashmap * atcs, const llama_token * inp_data, const int inp_size, const int nnew) -> void {
// atcs = all_token_counts: the hashmaps to modify.
// inp_data: the token sequence on which the hashmaps are based.
// inp_size: the current size of inp_data.
// nnew: how many new tokens have been appended to inp_data since the last call to this function.
//
// In order to get correct results, inp_data can ONLY BE APPENDED TO.
// Changes in the middle need a complete rebuild.
for (int ngram_size = ngram_min; ngram_size <= ngram_max; ++ngram_size) {
all_token_hashmap * atc = atcs + ngram_size - ngram_min;

const int i_start = std::max(inp_size - nnew, ngram_size);
for (int i = i_start; i < inp_size; ++i) {
const int ngram_start = i - ngram_size;
uint64_t ngram = inp_data[ngram_start];
for (int j = ngram_start; j < ngram_start + ngram_size; ++j) {
const uint64_t ngram_part = inp_data[j];
ngram <<= 16;
ngram |= ngram_part;
}
const llama_token token = inp_data[i];

all_token_hashmap::iterator token_counts_it = atc->find(ngram);
if (token_counts_it == atc->end()) {
token_hashmap token_counts;
token_counts.emplace(token, 1);
atc->emplace(ngram, token_counts);
} else {
token_hashmap::iterator tc_it = token_counts_it->second.find(token);
if (tc_it == token_counts_it->second.end()) {
token_counts_it->second.emplace(token, 1);
} else {
tc_it->second++;
}
}
}
}
};

all_token_hashmap all_token_counts[ngram_max-ngram_min+1];
int64_t t_draft_us = 0;

{
// Fill up hashmaps with tokens from user input:
const int64_t t_start_draft_us = ggml_time_us();
update_hashmaps(all_token_counts, inp.data(), inp.size(), inp.size());
t_draft_us += ggml_time_us() - t_start_draft_us;
}

const int max_context_size = llama_n_ctx(ctx);
const int max_tokens_list_size = max_context_size - 4;
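As a sanity check on the bookkeeping above, here is a minimal standalone sketch of what update_hashmaps records, reduced to 1-grams and hypothetical token ids (an illustration of the patch's encoding scheme, not part of the diff):

#include <cstdint>
#include <cstdio>
#include <unordered_map>

typedef std::unordered_map<int, int>                token_hashmap;     // token -> count
typedef std::unordered_map<uint64_t, token_hashmap> all_token_hashmap; // n-gram -> counts

int main() {
    const int seq[] = {5, 6, 5, 6, 5}; // hypothetical token ids
    const int seq_size = sizeof(seq)/sizeof(seq[0]);

    all_token_hashmap counts;
    for (int i = 1; i < seq_size; ++i) {
        // 1-gram key, built with the same seed-then-shift encoding as the patch:
        uint64_t ngram = seq[i-1];
        ngram = (ngram << 16) | (uint64_t) seq[i-1];
        counts[ngram][seq[i]]++;
    }

    // Prints, in unspecified order:
    //   after token 5: token 6 seen 2 time(s)
    //   after token 6: token 5 seen 2 time(s)
    for (const auto & nc : counts) {
        for (const auto & tc : nc.second) {
            printf("after token %d: token %d seen %d time(s)\n",
                   (int)(nc.first & 0xFFFF), tc.first, tc.second);
        }
    }
    return 0;
}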

@@ -75,8 +138,6 @@ int main(int argc, char ** argv){
int n_drafted = 0;
int n_accept = 0;

int64_t t_draft_us = 0;

int n_past = inp.size();

bool has_eos = false;
@@ -128,6 +189,12 @@ int main(int argc, char ** argv){
++n_past;
++i_dft;
inp.push_back(id);
{
// Update hashmaps with the newly accepted token:
const int64_t t_start_draft_us = ggml_time_us();
update_hashmaps(all_token_counts, inp.data(), inp.size(), 1);
t_draft_us += ggml_time_us() - t_start_draft_us;
}

if (params.use_color) {
// color accepted draft token
@@ -148,6 +215,12 @@ int main(int argc, char ** argv){
draft.clear();
draft.push_back(id);
inp.push_back(id);
{
// Update hashmaps with the newly accepted token:
const int64_t t_start_draft_us = ggml_time_us();
update_hashmaps(all_token_counts, inp.data(), inp.size(), 1);
t_draft_us += ggml_time_us() - t_start_draft_us;
}
break;
}

@@ -162,44 +235,85 @@ int main(int argc, char ** argv){
llama_batch_clear(batch_tgt);
llama_batch_add(batch_tgt, draft[0], n_past, { 0 }, true);

// generate n_pred tokens through prompt lookup
auto get_token = [](const std::vector<llama_token> & inp, const std::vector<llama_token> & draft, const size_t i) -> llama_token {
// Helper function to get a token from the combined, speculative sequence of inp and draft.
return i < inp.size() ? inp[i] : draft[1 + i - inp.size()];
};
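// Illustrative example: with inp.size() == 10 and draft[0] == inp.back(),
// get_token(inp, draft, 9) returns inp[9] while get_token(inp, draft, 10)
// returns draft[1], the first genuinely speculative token.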

auto prompt_lookup = [&]() -> void {
// Generate up to n_draft additional tokens through prompt lookup.
// The draft is aborted early if there is no suitable token candidate to continue the draft.
// At the beginning of this function the draft already contains a single token sampled from the model.
const int inp_size = inp.size();
for (int ngram_size = ngram_max ; ngram_size > ngram_min; --ngram_size){
const llama_token * ngram = &inp[inp_size - ngram_size];

for (int i = 0; i <= (int) inp_size - (ngram_size * 2); ++i) {
bool match = true;
for (int j = 0; j < ngram_size; ++j) {
if (inp[i + j] != ngram[j]) {
match = false;
break;
}

while ((int) draft.size()-1 < n_draft) {
bool draft_success = false;
for (int ngram_size = ngram_max; ngram_size >= ngram_min; --ngram_size) {
if (ngram_size > inp_size) {
continue;
}

all_token_hashmap & atc = all_token_counts[ngram_size - ngram_min];

const int ngram_start = inp_size-ngram_size + draft.size()-1;
uint64_t ngram = get_token(inp, draft, ngram_start);
for (int j = ngram_start; j < ngram_start + ngram_size; ++j) {
const uint64_t ngram_part = get_token(inp, draft, j);
ngram <<= 16;
ngram |= ngram_part;
}

if (match) {
const int startIdx = i + ngram_size;
const int endIdx = startIdx + n_draft;
if (endIdx < inp_size) {
for (int j = startIdx; j < endIdx; ++j) {
LOG(" - draft candidate %d: %d\n", j, inp[j]);
draft.push_back(inp[j]);
llama_batch_add(batch_tgt, inp[j], n_past + (j - startIdx) + 1, { 0 }, true);
++n_drafted;
}
return;
all_token_hashmap::iterator token_counts_it = atc.find(ngram);
if (token_counts_it == atc.end()) {
continue;
}
const token_hashmap & token_counts = token_counts_it->second;

int max_count = 0;
int sum_count = 0;
llama_token max_token = -1;

for (const std::pair<const llama_token, int> & tc : token_counts) {
const llama_token token = tc.first;
const int count = tc.second;

if (count > max_count) {
max_token = token;
max_count = count;
}
sum_count += count;
}
// Skip this candidate if the sample size is too low:
if (sum_count < draft_min_sample_size[ngram_size-1]) {
continue;
}
// Skip this candidate if the empirically most likely token following this n-gram is not frequent enough:
if (100*max_count < draft_min_percent[ngram_size-1]*sum_count) {
continue;
}

LOG(" - draft candidate: token=%d count=%d\n", max_token, max_count);
llama_batch_add(batch_tgt, max_token, n_past + draft.size(), { 0 }, true);
draft.push_back(max_token);
draft_success = true;
break;
}

if (!draft_success) {
break;
}
}
return;
};

// Draft already contains a single token sampled from the model:
GGML_ASSERT(draft.size() == 1);
GGML_ASSERT(draft[0] == inp.back());
const int64_t t_start_draft_us = ggml_time_us();

prompt_lookup();

t_draft_us += ggml_time_us() - t_start_draft_us;
n_drafted += draft.size() - 1;

llama_decode(ctx, batch_tgt);
++n_past;
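To see the integer threshold check above in action with hypothetical counts: suppose a matched 4-gram has been followed by token A 3 times and token B 2 times, so max_count = 3 and sum_count = 5. With draft_min_percent[3] = 50, the check passes because 100*3 = 300 >= 50*5 = 250, and A is appended to the draft. The same distribution behind a 1-gram would abort the draft, since draft_min_percent[0] = 66 requires 100*3 >= 66*5 = 330; the shorter the n-gram, the stronger the majority it must show before drafting continues.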