@@ -5,12 +5,13 @@
 #include <cmath>
 #include <cstdint>
 #include <cstdio>
+#include <fstream>
 #include <string>
 #include <vector>
 #include <unordered_map>
 
 // Data structures to map n-grams to empirical token probabilities:
-typedef std::unordered_map<llama_token, int>        token_hashmap; // token -> number of times token has been seen
+typedef std::unordered_map<llama_token, int32_t>    token_hashmap; // token -> number of times token has been seen
 typedef std::unordered_map<uint64_t, token_hashmap> all_token_hashmap; // n-gram -> empirical distribution of following tokens
 // n-grams are encoded as 64 bit integers with each of the 4 16 bit sections representing a token id.
 // This way no custom hashing function for the n-grams is needed.
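For readers who want to see the encoding concretely, here is a minimal standalone sketch (not part of the patch; `encode_ngram` and the token values are illustrative). It assumes token ids fit into 16 bits, which is the same assumption the `static_assert` below relies on:

```cpp
#include <cstdint>
#include <cstdio>

typedef int32_t llama_token; // stand-in for the llama.cpp typedef

// Pack an n-gram (n <= 4) into a uint64_t, 16 bits per token id.
static uint64_t encode_ngram(const llama_token * tokens, int ngram_size) {
    uint64_t ngram = tokens[0];
    for (int j = 1; j < ngram_size; ++j) {
        ngram <<= 16;
        ngram |= tokens[j];
    }
    return ngram;
}

int main() {
    const llama_token tokens[2] = {15043, 2787}; // two arbitrary token ids
    // 15043 = 0x3ac3, 2787 = 0x0ae3 -> key = 0x000000003ac30ae3
    printf("key = 0x%016llx\n", (unsigned long long) encode_ngram(tokens, 2));
}
```

Because the packed integer identifies the n-gram uniquely, `std::unordered_map`'s default `std::hash<uint64_t>` suffices and no custom hash functor is needed.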
@@ -22,7 +23,7 @@ static_assert(ngram_max <= sizeof(uint64_t)/2, "A 64 bit integer can only hold i
 
 // If sample size or percentage in context are below these thresholds the draft is aborted early:
 constexpr float draft_min_sample_size[ngram_max] = { 2,  2,  1,  1};
-constexpr float draft_min_percent[ngram_max]     = {66, 50, 50, 50};
+constexpr float draft_min_percent[ngram_max]     = {50, 50, 50, 50};
 
 int main(int argc, char ** argv){
     gpt_params params;
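The arrays are indexed by `ngram_size-1`, so this hunk relaxes only the 1-gram acceptance bar from 66% to 50%. A small sketch (function name hypothetical) of the early-abort check these constants feed into, with a worked example:

```cpp
// Illustrative only: mirrors the abort conditions used further down in the file.
static bool draft_candidate_ok(int ngram_size, int sum_count, int max_count) {
    if (sum_count < draft_min_sample_size[ngram_size-1]) {
        return false; // sample size too low
    }
    if (100*max_count < draft_min_percent[ngram_size-1]*sum_count) {
        return false; // most frequent continuation not dominant enough
    }
    return true;
}

// Example: a 1-gram whose best continuation was seen 6 times out of 10 (60%).
// Old threshold: 100*6 < 66*10 -> rejected. New threshold: 100*6 >= 50*10 -> accepted.
```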
@@ -100,12 +101,43 @@ int main(int argc, char ** argv){
     };
 
     all_token_hashmap all_token_counts[ngram_max-ngram_min+1];
+    all_token_hashmap static_all_token_counts;
     int64_t t_draft_us = 0;
 
     {
        // Fill up hashmaps with tokens from user input:
        const int64_t t_start_draft_us = ggml_time_us();
        update_hashmaps(all_token_counts, inp.data(), inp.size(), inp.size());
+
+       const char * hashmap_file_name = "lookup.bin";
+       std::ifstream hashmap_file(hashmap_file_name, std::ios::binary);
+       if (!hashmap_file) {
+           fprintf(stderr, "error: failed to open file '%s'\n", hashmap_file_name);
+           exit(1);
+       }
+       uint64_t    ngram;
+       int32_t     ntokens;
+       llama_token token;
+       int32_t     count;
+
+       char * ngramc   = reinterpret_cast<char *>(&ngram);
+       char * ntokensc = reinterpret_cast<char *>(&ntokens);
+       char * tokenc   = reinterpret_cast<char *>(&token);
+       char * countc   = reinterpret_cast<char *>(&count);
+       while (hashmap_file.read(ngramc, sizeof(uint64_t))) {
+           GGML_ASSERT(hashmap_file.read(ntokensc, sizeof(int32_t)));
+           token_hashmap token_counts;
+
+           for (int i = 0; i < ntokens; ++i) {
+               GGML_ASSERT(hashmap_file.read(tokenc, sizeof(llama_token)));
+               GGML_ASSERT(hashmap_file.read(countc, sizeof(int32_t)));
+               token_counts.emplace(token, count);
+           }
+
+           static_all_token_counts.emplace(ngram, token_counts);
+       }
+       GGML_ASSERT(hashmap_file.eof());
+
        t_draft_us += ggml_time_us() - t_start_draft_us;
     }
 
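The read loop implies a simple record format for `lookup.bin`: per n-gram, a `uint64_t` key, an `int32_t` entry count, then that many `(llama_token, int32_t)` pairs, all in native byte order. The patch only reads this file; a hypothetical writer producing the same layout (e.g. for a separate corpus-preprocessing tool) might look like this:

```cpp
// Hypothetical counterpart to the reader above; raw memory dumps, so the file
// is only portable between machines of the same endianness.
static void save_hashmap(const all_token_hashmap & atc, const char * file_name) {
    std::ofstream hashmap_file(file_name, std::ios::binary);
    for (const std::pair<const uint64_t, token_hashmap> & entry : atc) {
        const uint64_t ngram   = entry.first;
        const int32_t  ntokens = (int32_t) entry.second.size();
        hashmap_file.write(reinterpret_cast<const char *>(&ngram),   sizeof(uint64_t));
        hashmap_file.write(reinterpret_cast<const char *>(&ntokens), sizeof(int32_t));
        for (const std::pair<const llama_token, int32_t> & tc : entry.second) {
            const llama_token token = tc.first;
            const int32_t     count = tc.second;
            hashmap_file.write(reinterpret_cast<const char *>(&token), sizeof(llama_token));
            hashmap_file.write(reinterpret_cast<const char *>(&count), sizeof(int32_t));
        }
    }
}
```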
@@ -248,6 +280,20 @@ int main(int argc, char ** argv){
 
         while ((int) draft.size()-1 < n_draft) {
             bool draft_success = false;
+
+            const int static_ngram_start = inp_size-2 + draft.size()-1;
+            uint64_t static_ngram = get_token(inp, draft, static_ngram_start);
+            for (int j = static_ngram_start+1; j < static_ngram_start + 2; ++j) {
+                const uint64_t ngram_part = get_token(inp, draft, j);
+                static_ngram <<= 16;
+                static_ngram |= ngram_part;
+            }
+            all_token_hashmap::iterator static_token_counts_it = static_all_token_counts.find(static_ngram);
+            token_hashmap static_token_counts;
+            if (static_token_counts_it != static_all_token_counts.end()) {
+                static_token_counts = static_token_counts_it->second;
+            }
+
             for (int ngram_size = ngram_max; ngram_size >= ngram_min; --ngram_size) {
                 if (ngram_size > inp_size) {
                     continue;
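The hard-coded `2` means the static table is always queried with a 2-gram: the last two tokens of the virtual sequence formed by the context followed by the draft so far. A sketch of the indexing convention, assuming `get_token` (defined earlier in this file) reads from that concatenation:

```cpp
// Illustrative restatement of the indexing assumed above: position i addresses
// inp for i < inp.size() and the draft (offset by one) afterwards.
static llama_token get_token_sketch(const std::vector<llama_token> & inp,
                                    const std::vector<llama_token> & draft,
                                    size_t i) {
    return i < inp.size() ? inp[i] : draft[1 + i - inp.size()];
}

// With inp_size = 10 and draft.size() = 1, static_ngram_start = 10-2 + 1-1 = 8,
// so the 2-gram key covers positions 8 and 9: the last two context tokens.
```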
@@ -270,16 +316,21 @@ int main(int argc, char ** argv){
                 const token_hashmap token_counts = token_counts_it->second;
 
                 int max_count        = 0;
+                int max_count_static = 0;
                 int sum_count        = 0;
                 llama_token max_token = -1;
 
                 for (std::pair<llama_token, int> tc : token_counts) {
                     const llama_token token = tc.first;
-                    const llama_token count = tc.second;
 
-                    if (count > max_count) {
-                        max_token = token;
-                        max_count = count;
+                    token_hashmap::iterator stc_it = static_token_counts.find(token);
+                    const int32_t count        = tc.second;
+                    const int32_t count_static = stc_it != static_token_counts.end() ? 100*stc_it->second : 1;
+
+                    if (count*count_static > max_count*max_count_static) {
+                        max_token        = token;
+                        max_count        = count;
+                        max_count_static = count_static;
                     }
                     sum_count += count;
                 }
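The combined score multiplies the in-context count by `100*static_count` when the static table corroborates a token and by `1` otherwise, so corpus statistics dominate without being able to veto a context-only candidate. A worked example (helper name hypothetical):

```cpp
#include <cstdio>

// Mirrors the comparison above; long avoids overflow in the sketch.
static long combined_score(int count, int static_count) {
    const int count_static = static_count > 0 ? 100*static_count : 1;
    return (long) count * count_static;
}

int main() {
    printf("token A: %ld\n", combined_score(3, 7)); // 3 * 700 = 2100
    printf("token B: %ld\n", combined_score(5, 0)); // 5 *   1 =    5
    // A is drafted despite fewer in-context observations; if no token has
    // static support, every weight is 1 and the old count comparison remains.
}
```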
@@ -299,6 +350,38 @@ int main(int argc, char ** argv){
                 break;
             }
 
+            if (!draft_success) {
+                int max_count = 0;
+                int sum_count = 0;
+                llama_token max_token = -1;
+
+                for (std::pair<llama_token, int> tc : static_token_counts) {
+                    const llama_token token = tc.first;
+                    const int32_t     count = tc.second;
+
+                    if (count > max_count) {
+                        max_token = token;
+                        max_count = count;
+                    }
+                    sum_count += count;
+                }
+
+                // Skip this candidate if the sample size is too low:
+                if (sum_count < draft_min_sample_size[2-1]) {
+                    break;
+                }
+                // Skip this candidate if the empirically most likely token following this n-gram is not likely enough:
+                if (100*max_count < draft_min_percent[2-1]*sum_count) {
+                    break;
+                }
+
+                LOG(" - draft candidate: token=%d count=%d\n", max_token, max_count);
+                llama_batch_add(batch_tgt, max_token, n_past + draft.size(), { 0 }, true);
+                draft.push_back(max_token);
+                draft_success = true;
+                break;
+            }
+
             if (!draft_success) {
                 break;
             }