
Commit 8f62e5d

lookup: hashmap, most frequent tokens, abort early
1 parent 76e8688 commit 8f62e5d


examples/lookup/lookup.cpp

Lines changed: 142 additions & 28 deletions
@@ -7,6 +7,22 @@
 #include <cstdio>
 #include <string>
 #include <vector>
+#include <unordered_map>
+
+// Data structures to map n-grams to empirical token probabilities:
+typedef std::unordered_map<llama_token, int> token_hashmap; // token -> number of times token has been seen
+typedef std::unordered_map<uint64_t, token_hashmap> all_token_hashmap; // n-gram -> empirical distribution of following tokens
+// n-grams are encoded as 64 bit integers with each of the 4 16 bit sections representing a token id.
+// This way no custom hashing function for the n-grams is needed.
+
+// Min/max n-gram size to search for in prompt:
+constexpr int ngram_min = 1;
+constexpr int ngram_max = 4;
+static_assert(ngram_max <= sizeof(uint64_t)/2, "A 64 bit integer can only hold information for 4 16 bit tokens.");
+
+// If sample size or percentage in context are below these thresholds the draft is aborted early:
+constexpr float draft_min_sample_size[ngram_max] = { 2, 2, 1, 1};
+constexpr float draft_min_percent[ngram_max] = {66, 50, 50, 50};
 
 int main(int argc, char ** argv){
     gpt_params params;
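Note: the new typedefs rely on the n-gram encoding described in the comments above, where each token id occupies one 16 bit slice of a 64 bit key; this is also why the diff later asserts that the vocabulary size is below 1 << 16. As a standalone illustration only (not part of the commit; the token ids are made up), this is what the shift/or packing computes for a single n-gram:

// Illustration only: fold an n-gram of 16 bit token ids into one uint64_t key,
// mirroring the shift/or loop used in the diff.
#include <cstdint>
#include <cstdio>

int main() {
    const int32_t inp_data[4] = {17, 4242, 9, 301}; // token ids, each assumed < 2^16
    const int ngram_size  = 3;
    const int ngram_start = 0;

    uint64_t ngram = inp_data[ngram_start];
    for (int j = ngram_start; j < ngram_start + ngram_size; ++j) {
        ngram <<= 16;                    // make room for the next 16 bit id
        ngram |= (uint64_t) inp_data[j]; // append it to the low bits
    }
    // The key is deterministic for a given n-gram, so it can be used directly as an
    // unordered_map key without writing a custom hash for arrays of tokens.
    printf("ngram key = 0x%016llx\n", (unsigned long long) ngram);
    return 0;
}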
@@ -16,9 +32,6 @@ int main(int argc, char ** argv){
     }
 
     // max/min n-grams size to search for in prompt
-    const int ngram_max = 4;
-    const int ngram_min = 1;
-
     // length of the candidate / draft sequence, if match is found
     const int n_draft = params.n_draft;
 
@@ -39,6 +52,7 @@ int main(int argc, char ** argv){
 
     // load the model
     std::tie(model, ctx) = llama_init_from_gpt_params(params);
+    GGML_ASSERT(llama_n_vocab(model) < (1 << 16));
 
     // tokenize the prompt
     const bool add_bos = llama_should_add_bos_token(model);
@@ -47,6 +61,55 @@ int main(int argc, char ** argv){
     std::vector<llama_token> inp;
     inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
 
+    auto update_hashmaps = [](all_token_hashmap * atcs, const llama_token * inp_data, const int inp_size, const int nnew) -> void {
+        // atcs = all_token_counts: the hashmaps to modify.
+        // inp_data: the token sequence on which the hashmaps are based.
+        // inp_size: the current size of inp_data.
+        // nnew: how many new tokens have been appended to inp_data since the last call to this function.
+        //
+        // In order to get correct results inp_data can ONLY BE APPENDED TO.
+        // Changes in the middle need a complete rebuild.
+        for (int ngram_size = ngram_min; ngram_size <= ngram_max; ++ngram_size) {
+            all_token_hashmap * atc = atcs + ngram_size - ngram_min;
+
+            const int i_start = std::max(inp_size - nnew, ngram_size);
+            for (int i = i_start; i < inp_size; ++i) {
+                const int ngram_start = i - ngram_size;
+                uint64_t ngram = inp_data[ngram_start];
+                for (int j = ngram_start; j < ngram_start + ngram_size; ++j) {
+                    const uint64_t ngram_part = inp_data[j];
+                    ngram <<= 16;
+                    ngram |= ngram_part;
+                }
+                const llama_token token = inp_data[i];
+
+                all_token_hashmap::iterator token_counts_it = atc->find(ngram);
+                if (token_counts_it == atc->end()) {
+                    token_hashmap token_counts;
+                    token_counts.emplace(token, 1);
+                    atc->emplace(ngram, token_counts);
+                } else {
+                    token_hashmap::iterator tc_it = token_counts_it->second.find(token);
+                    if (tc_it == token_counts_it->second.end()) {
+                        token_counts_it->second.emplace(token, 1);
+                    } else {
+                        tc_it->second++;
+                    }
+                }
+            }
+        }
+    };
+
+    all_token_hashmap all_token_counts[ngram_max-ngram_min+1];
+    int64_t t_draft_us = 0;
+
+    {
+        // Fill up hashmaps with tokens from user input:
+        const int64_t t_start_draft_us = ggml_time_us();
+        update_hashmaps(all_token_counts, inp.data(), inp.size(), inp.size());
+        t_draft_us += ggml_time_us() - t_start_draft_us;
+    }
+
     const int max_context_size = llama_n_ctx(ctx);
     const int max_tokens_list_size = max_context_size - 4;
 
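Note: the update_hashmaps lambda above maintains, for every n-gram length, a map from the encoded n-gram to the counts of the tokens that followed it. A simplified standalone sketch of the same two-level counting idea (not from the commit; the stand-in llama_token typedef, the toy token sequence, and the 1-gram-only loop are made up for illustration):

// Sketch: count which token follows each 1-gram in a toy sequence and read the
// empirical distribution back out, using the same nested-hashmap layout.
#include <cstdint>
#include <cstdio>
#include <unordered_map>
#include <vector>

typedef int32_t llama_token; // stand-in for the llama.cpp typedef
typedef std::unordered_map<llama_token, int> token_hashmap;
typedef std::unordered_map<uint64_t, token_hashmap> all_token_hashmap;

int main() {
    const std::vector<llama_token> inp = {5, 7, 5, 7, 5, 9};
    const int ngram_size = 1;

    all_token_hashmap atc; // encoded 1-gram -> counts of the token that followed it
    for (int i = ngram_size; i < (int) inp.size(); ++i) {
        uint64_t ngram = inp[i - ngram_size];
        for (int j = i - ngram_size; j < i; ++j) {
            ngram <<= 16;
            ngram |= (uint64_t) inp[j];
        }
        atc[ngram][inp[i]]++; // operator[] default-constructs missing entries
    }

    // Token 5 was followed by 7 twice and by 9 once, so 7 is the most frequent follower.
    const uint64_t key = (uint64_t(5) << 16) | 5; // encoding of the 1-gram {5}, built as above
    for (const auto & tc : atc[key]) {
        printf("follower %d seen %d time(s)\n", tc.first, tc.second);
    }
    return 0;
}

In the commit itself the build is incremental: i_start = std::max(inp_size - nnew, ngram_size) means only n-grams that end at newly appended tokens are revisited, so after the initial pass over the prompt each accepted token costs just one update per n-gram length.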
@@ -76,8 +139,6 @@ int main(int argc, char ** argv){
     int n_drafted = 0;
     int n_accept = 0;
 
-    int64_t t_draft_us = 0;
-
     int n_past = inp.size();
 
     bool has_eos = false;
@@ -129,6 +190,12 @@ int main(int argc, char ** argv){
                 ++n_past;
                 ++i_dft;
                 inp.push_back(id);
+                {
+                    // Update hashmaps with the newly accepted token:
+                    const int64_t t_start_draft_us = ggml_time_us();
+                    update_hashmaps(all_token_counts, inp.data(), inp.size(), 1);
+                    t_draft_us += ggml_time_us() - t_start_draft_us;
+                }
 
                 if (params.use_color) {
                     // color accepted draft token
@@ -149,6 +216,12 @@ int main(int argc, char ** argv){
             draft.clear();
             draft.push_back(id);
             inp.push_back(id);
+            {
+                // Update hashmaps with the newly accepted token:
+                const int64_t t_start_draft_us = ggml_time_us();
+                update_hashmaps(all_token_counts, inp.data(), inp.size(), 1);
+                t_draft_us += ggml_time_us() - t_start_draft_us;
+            }
             break;
         }
 
@@ -163,44 +236,85 @@ int main(int argc, char ** argv){
         llama_batch_clear(batch_tgt);
         llama_batch_add(batch_tgt, draft[0], n_past, { 0 }, true);
 
-        // generate n_pred tokens through prompt lookup
+        auto get_token = [](const std::vector<llama_token> inp, const std::vector<llama_token> draft, const size_t i) -> llama_token {
+            // Helper function to get a token from the combined, speculative sequence of inp and draft.
+            return i < inp.size() ? inp[i] : draft[1 + i - inp.size()];
+        };
+
         auto prompt_lookup = [&]() -> void {
+            // Generate up to n_draft additional tokens through prompt lookup.
+            // The draft is aborted early if there is no suitable token candidate to continue the draft.
+            // At the beginning of this function the draft already contains a single token sampled from the model.
             const int inp_size = inp.size();
-            for (int ngram_size = ngram_max ; ngram_size > ngram_min; --ngram_size){
-                const llama_token * ngram = &inp[inp_size - ngram_size];
-
-                for (int i = 0; i <= (int) inp_size - (ngram_size * 2); ++i) {
-                    bool match = true;
-                    for (int j = 0; j < ngram_size; ++j) {
-                        if (inp[i + j] != ngram[j]) {
-                            match = false;
-                            break;
-                        }
+
+            while ((int) draft.size()-1 < n_draft) {
+                bool draft_success = false;
+                for (int ngram_size = ngram_max; ngram_size >= ngram_min; --ngram_size) {
+                    if (ngram_size > inp_size) {
+                        continue;
+                    }
+
+                    all_token_hashmap & atc = all_token_counts[ngram_size - ngram_min];
+
+                    const int ngram_start = inp_size-ngram_size + draft.size()-1;
+                    uint64_t ngram = get_token(inp, draft, ngram_start);
+                    for (int j = ngram_start; j < ngram_start + ngram_size; ++j) {
+                        const uint64_t ngram_part = get_token(inp, draft, j);
+                        ngram <<= 16;
+                        ngram |= ngram_part;
                     }
 
-                    if (match) {
-                        const int startIdx = i + ngram_size;
-                        const int endIdx = startIdx + n_draft;
-                        if (endIdx < inp_size) {
-                            for (int j = startIdx; j < endIdx; ++j) {
-                                LOG(" - draft candidate %d: %d\n", j, inp[j]);
-                                draft.push_back(inp[j]);
-                                llama_batch_add(batch_tgt, inp[j], n_past + (j - startIdx) + 1, { 0 }, true);
-                                ++n_drafted;
-                            }
-                            return;
+                    all_token_hashmap::iterator token_counts_it = atc.find(ngram);
+                    if (token_counts_it == atc.end()) {
+                        continue;
+                    }
+                    const token_hashmap token_counts = token_counts_it->second;
+
+                    int max_count = 0;
+                    int sum_count = 0;
+                    llama_token max_token = -1;
+
+                    for (std::pair<llama_token, int> tc : token_counts) {
+                        const llama_token token = tc.first;
+                        const llama_token count = tc.second;
+
+                        if (count > max_count) {
+                            max_token = token;
+                            max_count = count;
                         }
+                        sum_count += count;
+                    }
+                    // Skip this candidate if the sample size is too low:
+                    if (sum_count < draft_min_sample_size[ngram_size-1]) {
+                        continue;
                     }
+                    // skip this candidate if the empirically most likely token following this token is not likely enough:
+                    if (100*max_count < draft_min_percent[ngram_size-1]*sum_count) {
+                        continue;
+                    }
+
+                    LOG(" - draft candidate: token=%d count=%d\n", max_token, max_count);
+                    llama_batch_add(batch_tgt, max_token, n_past + draft.size(), { 0 }, true);
+                    draft.push_back(max_token);
+                    draft_success = true;
+                    break;
+                }
+
+                if (!draft_success) {
+                    break;
                 }
             }
-            return;
         };
 
+        // Draft already contains a single token sampled from the model:
+        GGML_ASSERT(draft.size() == 1);
+        GGML_ASSERT(draft[0] == inp.back());
        const int64_t t_start_draft_us = ggml_time_us();
 
         prompt_lookup();
 
         t_draft_us += ggml_time_us() - t_start_draft_us;
+        n_drafted += draft.size() - 1;
 
         llama_decode(ctx, batch_tgt);
         ++n_past;
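Note: a small standalone sketch (not from the commit; the counts are invented) of how the two abort-early thresholds defined at the top of the file decide whether the most frequent follower of the current n-gram gets drafted:

// Sketch of the abort-early rule with made-up counts for a 2-gram candidate.
#include <cstdio>

int main() {
    constexpr float draft_min_sample_size[4] = { 2, 2, 1, 1};
    constexpr float draft_min_percent[4]     = {66, 50, 50, 50};

    const int ngram_size = 2;
    const int max_count  = 3; // occurrences of the most frequent follower
    const int sum_count  = 4; // total occurrences of this 2-gram so far

    const bool enough_samples = sum_count >= draft_min_sample_size[ngram_size-1];           // 4 >= 2
    const bool likely_enough  = 100*max_count >= draft_min_percent[ngram_size-1]*sum_count; // 300 >= 200
    printf("draft this candidate: %s\n", enough_samples && likely_enough ? "yes" : "no");
    return 0;
}

With max_count = 2 the candidate would still pass (100*2 = 200 is not below 50*4 = 200); only at 1 out of 4 (100 < 200) would this n-gram length be skipped and a shorter one tried, or the draft aborted.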
