
Commit e3d8c5e

lookup: hashmap, most frequent tokens, abort early
1 parent 4a46d2b commit e3d8c5e

File tree: 1 file changed (+142 −28)


examples/lookup/lookup.cpp

Lines changed: 142 additions & 28 deletions

@@ -7,6 +7,22 @@
 #include <cstdio>
 #include <string>
 #include <vector>
+#include <unordered_map>
+
+// Data structures to map n-grams to empirical token probabilities:
+typedef std::unordered_map<llama_token, int> token_hashmap;            // token -> number of times token has been seen
+typedef std::unordered_map<uint64_t, token_hashmap> all_token_hashmap; // n-gram -> empirical distribution of following tokens
+// n-grams are encoded as 64 bit integers with each of the 4 16 bit sections representing a token id.
+// This way no custom hashing function for the n-grams is needed.
+
+// Min/max n-gram size to search for in prompt:
+constexpr int ngram_min = 1;
+constexpr int ngram_max = 4;
+static_assert(ngram_max <= sizeof(uint64_t)/2, "A 64 bit integer can only hold information for 4 16 bit tokens.");
+
+// If sample size or percentage in context are below these thresholds the draft is aborted early:
+constexpr float draft_min_sample_size[ngram_max] = { 2,  2,  1,  1};
+constexpr float draft_min_percent[ngram_max]     = {66, 50, 50, 50};
 
 int main(int argc, char ** argv){
     gpt_params params;
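
For reference, a minimal standalone sketch of the n-gram encoding described in the new comments: up to 4 token ids (each assumed to fit in 16 bits, which the commit asserts via llama_n_vocab further down) are packed into a single uint64_t, so std::unordered_map can use its default integer hash for the key. The helper name encode_ngram and the raw int32_t token type are illustrative, not part of the commit; the commit's own loops seed the key with the first token before shifting, but apply the same construction on both the update and lookup paths, so the keys stay consistent.

#include <cstdint>
#include <cstdio>

// Pack up to 4 token ids (each < 2^16) into one 64-bit key, oldest token in the highest bits.
static uint64_t encode_ngram(const int32_t * tokens, const int ngram_size) {
    uint64_t ngram = 0;
    for (int j = 0; j < ngram_size; ++j) {
        ngram <<= 16;
        ngram |= (uint64_t) tokens[j];
    }
    return ngram;
}

int main() {
    const int32_t tokens[4] = {17, 4242, 3, 900};
    printf("key = 0x%016llx\n", (unsigned long long) encode_ngram(tokens, 4));
    // prints: key = 0x0011109200030384 (0x0011 = 17, 0x1092 = 4242, 0x0003 = 3, 0x0384 = 900)
    return 0;
}
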
@@ -16,9 +32,6 @@ int main(int argc, char ** argv){
     }
 
     // max/min n-grams size to search for in prompt
-    const int ngram_max = 4;
-    const int ngram_min = 1;
-
     // length of the candidate / draft sequence, if match is found
     const int n_draft = params.n_draft;

@@ -38,6 +51,7 @@ int main(int argc, char ** argv){
 
     // load the model
     std::tie(model, ctx) = llama_init_from_gpt_params(params);
+    GGML_ASSERT(llama_n_vocab(model) < (1 << 16));
 
     // tokenize the prompt
     const bool add_bos = llama_should_add_bos_token(model);

@@ -46,6 +60,55 @@ int main(int argc, char ** argv){
     std::vector<llama_token> inp;
     inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
 
+    auto update_hashmaps = [](all_token_hashmap * atcs, const llama_token * inp_data, const int inp_size, const int nnew) -> void {
+        // atcs = all_token_counts: the hashmaps to modify.
+        // inp_data: the token sequence on which the hashmaps are based.
+        // inp_size: the current size of inp_data.
+        // nnew: how many new tokens have been appended to inp_data since the last call to this function.
+        //
+        // In order to get correct results inp_data can ONLY BE APPENDED TO.
+        // Changes in the middle need a complete rebuild.
+        for (int ngram_size = ngram_min; ngram_size <= ngram_max; ++ngram_size) {
+            all_token_hashmap * atc = atcs + ngram_size - ngram_min;
+
+            const int i_start = std::max(inp_size - nnew, ngram_size);
+            for (int i = i_start; i < inp_size; ++i) {
+                const int ngram_start = i - ngram_size;
+                uint64_t ngram = inp_data[ngram_start];
+                for (int j = ngram_start; j < ngram_start + ngram_size; ++j) {
+                    const uint64_t ngram_part = inp_data[j];
+                    ngram <<= 16;
+                    ngram |= ngram_part;
+                }
+                const llama_token token = inp_data[i];
+
+                all_token_hashmap::iterator token_counts_it = atc->find(ngram);
+                if (token_counts_it == atc->end()) {
+                    token_hashmap token_counts;
+                    token_counts.emplace(token, 1);
+                    atc->emplace(ngram, token_counts);
+                } else {
+                    token_hashmap::iterator tc_it = token_counts_it->second.find(token);
+                    if (tc_it == token_counts_it->second.end()) {
+                        token_counts_it->second.emplace(token, 1);
+                    } else {
+                        tc_it->second++;
+                    }
+                }
+            }
+        }
+    };
+
+    all_token_hashmap all_token_counts[ngram_max-ngram_min+1];
+    int64_t t_draft_us = 0;
+
+    {
+        // Fill up hashmaps with tokens from user input:
+        const int64_t t_start_draft_us = ggml_time_us();
+        update_hashmaps(all_token_counts, inp.data(), inp.size(), inp.size());
+        t_draft_us += ggml_time_us() - t_start_draft_us;
+    }
+
     const int max_context_size     = llama_n_ctx(ctx);
     const int max_tokens_list_size = max_context_size - 4;
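
To illustrate the structure that update_hashmaps builds, a small self-contained sketch (illustrative type names and plain integers in place of llama_token; operator[] is used for brevity where the commit uses explicit find/emplace): for the token sequence 5 7 9 5 7 3 and ngram_size = 2, the bigram (5, 7) maps to the follower counts {9: 1, 3: 1}.

#include <cstdint>
#include <cstdio>
#include <unordered_map>
#include <vector>

// Illustrative stand-ins for the commit's token_hashmap / all_token_hashmap types:
typedef std::unordered_map<int32_t, int> follower_counts;               // follower token -> count
typedef std::unordered_map<uint64_t, follower_counts> ngram_count_map;  // packed n-gram -> follower counts

int main() {
    const std::vector<int32_t> tokens = {5, 7, 9, 5, 7, 3};
    const int ngram_size = 2;

    ngram_count_map counts;
    for (size_t i = ngram_size; i < tokens.size(); ++i) {
        uint64_t key = 0;
        for (size_t j = i - ngram_size; j < i; ++j) {
            key = (key << 16) | (uint64_t) tokens[j];
        }
        counts[key][tokens[i]] += 1; // operator[] default-constructs missing entries
    }

    // The bigram (5, 7) has been followed by 9 once and by 3 once:
    const uint64_t key_5_7 = (5u << 16) | 7u;
    for (const auto & kv : counts[key_5_7]) {
        printf("after (5, 7): token %d seen %d time(s)\n", kv.first, kv.second);
    }
    return 0;
}
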

@@ -75,8 +138,6 @@ int main(int argc, char ** argv){
     int n_drafted = 0;
     int n_accept = 0;
 
-    int64_t t_draft_us = 0;
-
     int n_past = inp.size();
 
     bool has_eos = false;

@@ -128,6 +189,12 @@ int main(int argc, char ** argv){
                 ++n_past;
                 ++i_dft;
                 inp.push_back(id);
+                {
+                    // Update hashmaps with the newly accepted token:
+                    const int64_t t_start_draft_us = ggml_time_us();
+                    update_hashmaps(all_token_counts, inp.data(), inp.size(), 1);
+                    t_draft_us += ggml_time_us() - t_start_draft_us;
+                }
 
                 if (params.use_color) {
                     // color accepted draft token

@@ -148,6 +215,12 @@ int main(int argc, char ** argv){
             draft.clear();
             draft.push_back(id);
             inp.push_back(id);
+            {
+                // Update hashmaps with the newly accepted token:
+                const int64_t t_start_draft_us = ggml_time_us();
+                update_hashmaps(all_token_counts, inp.data(), inp.size(), 1);
+                t_draft_us += ggml_time_us() - t_start_draft_us;
+            }
             break;
         }

@@ -162,44 +235,85 @@ int main(int argc, char ** argv){
         llama_batch_clear(batch_tgt);
         llama_batch_add(batch_tgt, draft[0], n_past, { 0 }, true);
 
-        // generate n_pred tokens through prompt lookup
+        auto get_token = [](const std::vector<llama_token> inp, const std::vector<llama_token> draft, const size_t i) -> llama_token {
+            // Helper function to get a token from the combined, speculative sequence of inp and draft.
+            return i < inp.size() ? inp[i] : draft[1 + i - inp.size()];
+        };
+
         auto prompt_lookup = [&]() -> void {
+            // Generate up to n_draft additional tokens through prompt lookup.
+            // The draft is aborted early if there is no suitable token candidate to continue the draft.
+            // At the beginning of this function the draft already contains a single token sampled from the model.
            const int inp_size = inp.size();
-            for (int ngram_size = ngram_max ; ngram_size > ngram_min; --ngram_size){
-                const llama_token * ngram = &inp[inp_size - ngram_size];
-
-                for (int i = 0; i <= (int) inp_size - (ngram_size * 2); ++i) {
-                    bool match = true;
-                    for (int j = 0; j < ngram_size; ++j) {
-                        if (inp[i + j] != ngram[j]) {
-                            match = false;
-                            break;
-                        }
+
+            while ((int) draft.size()-1 < n_draft) {
+                bool draft_success = false;
+                for (int ngram_size = ngram_max; ngram_size >= ngram_min; --ngram_size) {
+                    if (ngram_size > inp_size) {
+                        continue;
+                    }
+
+                    all_token_hashmap & atc = all_token_counts[ngram_size - ngram_min];
+
+                    const int ngram_start = inp_size-ngram_size + draft.size()-1;
+                    uint64_t ngram = get_token(inp, draft, ngram_start);
+                    for (int j = ngram_start; j < ngram_start + ngram_size; ++j) {
+                        const uint64_t ngram_part = get_token(inp, draft, j);
+                        ngram <<= 16;
+                        ngram |= ngram_part;
                     }
 
-                    if (match) {
-                        const int startIdx = i + ngram_size;
-                        const int endIdx = startIdx + n_draft;
-                        if (endIdx < inp_size) {
-                            for (int j = startIdx; j < endIdx; ++j) {
-                                LOG(" - draft candidate %d: %d\n", j, inp[j]);
-                                draft.push_back(inp[j]);
-                                llama_batch_add(batch_tgt, inp[j], n_past + (j - startIdx) + 1, { 0 }, true);
-                                ++n_drafted;
-                            }
-                            return;
+                    all_token_hashmap::iterator token_counts_it = atc.find(ngram);
+                    if (token_counts_it == atc.end()) {
+                        continue;
+                    }
+                    const token_hashmap token_counts = token_counts_it->second;
+
+                    int max_count = 0;
+                    int sum_count = 0;
+                    llama_token max_token = -1;
+
+                    for (std::pair<llama_token, int> tc : token_counts) {
+                        const llama_token token = tc.first;
+                        const llama_token count = tc.second;
+
+                        if (count > max_count) {
+                            max_token = token;
+                            max_count = count;
                         }
+                        sum_count += count;
+                    }
+                    // Skip this candidate if the sample size is too low:
+                    if (sum_count < draft_min_sample_size[ngram_size-1]) {
+                        continue;
                     }
+                    // skip this candidate if the empirically most likely token following this token is not likely enough:
+                    if (100*max_count < draft_min_percent[ngram_size-1]*sum_count) {
+                        continue;
+                    }
+
+                    LOG(" - draft candidate: token=%d count=%d\n", max_token, max_count);
+                    llama_batch_add(batch_tgt, max_token, n_past + draft.size(), { 0 }, true);
+                    draft.push_back(max_token);
+                    draft_success = true;
+                    break;
+                }
+
+                if (!draft_success) {
+                    break;
                 }
             }
-            return;
         };
 
+        // Draft already contains a single token sampled from the model:
+        GGML_ASSERT(draft.size() == 1);
+        GGML_ASSERT(draft[0] == inp.back());
        const int64_t t_start_draft_us = ggml_time_us();
 
         prompt_lookup();
 
         t_draft_us += ggml_time_us() - t_start_draft_us;
+        n_drafted += draft.size() - 1;
 
         llama_decode(ctx, batch_tgt);
         ++n_past;
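
A condensed sketch of the per-n-gram decision that prompt_lookup now makes, matching the commit message (most frequent follower token, abort early): take the follower with the highest empirical count, and reject the candidate when the total sample size or the best follower's share falls below the per-size thresholds draft_min_sample_size / draft_min_percent. The free function pick_draft_token, its -1 "abort" return value, and the example token ids are illustrative, not part of the commit.

#include <cstdio>
#include <unordered_map>

// Return the empirically most frequent follower token for one n-gram,
// or -1 if the statistics are too weak to risk a draft (abort early).
static int pick_draft_token(const std::unordered_map<int, int> & token_counts,
                            const int min_sample_size, const int min_percent) {
    int max_token = -1;
    int max_count =  0;
    int sum_count =  0;
    for (const auto & tc : token_counts) {
        if (tc.second > max_count) {
            max_token = tc.first;
            max_count = tc.second;
        }
        sum_count += tc.second;
    }
    if (sum_count < min_sample_size) {
        return -1; // too few observations of this n-gram
    }
    if (100*max_count < min_percent*sum_count) {
        return -1; // most frequent follower is not dominant enough
    }
    return max_token;
}

int main() {
    // Suppose a 1-gram has been followed by token 101 three times and by token 202 once (hypothetical ids):
    const std::unordered_map<int, int> followers = {{101, 3}, {202, 1}};
    // 1-gram thresholds from the commit: sample size >= 2, best share >= 66 %.
    printf("drafted token: %d\n", pick_draft_token(followers, 2, 66));
    return 0;
}

With the 1-gram thresholds above (sample size >= 2, share >= 66 %), the example drafts token 101 because it accounts for 3 of the 4 observed continuations.
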

0 commit comments
