 #include <cmath>
 #include <cstdint>
 #include <cstdio>
+#include <fstream>
 #include <string>
 #include <vector>
 #include <unordered_map>

 // Data structures to map n-grams to empirical token probabilities:
-typedef std::unordered_map<llama_token, int>        token_hashmap;     // token -> number of times token has been seen
+typedef std::unordered_map<llama_token, int32_t>    token_hashmap;     // token -> number of times token has been seen
 typedef std::unordered_map<uint64_t, token_hashmap> all_token_hashmap; // n-gram -> empirical distribution of following tokens
 // n-grams are encoded as 64 bit integers with each of the 4 16 bit sections representing a token id.
 // This way no custom hashing function for the n-grams is needed.
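As an aside, the encoding described in the two comments above can be made concrete with a small standalone sketch; `pack_ngram` below is a hypothetical helper for illustration only (the real keys are built inline by `update_hashmaps` and the drafting loop), and it assumes token ids fit into 16 bits:

```cpp
#include <cstdint>

// Pack up to 4 token ids into one uint64_t key, newest token in the low 16 bits.
static uint64_t pack_ngram(const int32_t * tokens, int ngram_size) {
    uint64_t ngram = 0;
    for (int i = 0; i < ngram_size; ++i) {
        ngram <<= 16;
        ngram |= static_cast<uint16_t>(tokens[i]); // assumes token ids < 65536
    }
    return ngram; // e.g. the 2-gram {1, 2} becomes 0x0000000000010002
}
```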
@@ -22,7 +23,7 @@ static_assert(ngram_max <= sizeof(uint64_t)/2, "A 64 bit integer can only hold i
 
 // If sample size or percentage in context are below these thresholds the draft is aborted early:
 constexpr float draft_min_sample_size[ngram_max] = { 2,  2,  1,  1};
-constexpr float draft_min_percent[ngram_max]     = {66, 50, 50, 50};
+constexpr float draft_min_percent[ngram_max]     = {50, 50, 50, 50};
 
 int main(int argc, char ** argv){
     gpt_params params;
@@ -101,12 +102,43 @@ int main(int argc, char ** argv){
     };
 
     all_token_hashmap all_token_counts[ngram_max-ngram_min+1];
+    all_token_hashmap static_all_token_counts;
     int64_t t_draft_us = 0;
 
     {
         // Fill up hashmaps with tokens from user input:
         const int64_t t_start_draft_us = ggml_time_us();
         update_hashmaps(all_token_counts, inp.data(), inp.size(), inp.size());
+
+        const char * hashmap_file_name = "lookup.bin";
+        std::ifstream hashmap_file(hashmap_file_name, std::ios::binary);
+        if (!hashmap_file) {
+            fprintf(stderr, "error: failed to open file '%s'\n", hashmap_file_name);
+            exit(1);
+        }
+        uint64_t ngram;
+        int32_t ntokens;
+        llama_token token;
+        int32_t count;
+
+        char * ngramc   = reinterpret_cast<char *>(&ngram);
+        char * ntokensc = reinterpret_cast<char *>(&ntokens);
+        char * tokenc   = reinterpret_cast<char *>(&token);
+        char * countc   = reinterpret_cast<char *>(&count);
+        while (hashmap_file.read(ngramc, sizeof(uint64_t))) {
+            GGML_ASSERT(hashmap_file.read(ntokensc, sizeof(int32_t)));
+            token_hashmap token_counts;
+
+            for (int i = 0; i < ntokens; ++i) {
+                GGML_ASSERT(hashmap_file.read(tokenc, sizeof(llama_token)));
+                GGML_ASSERT(hashmap_file.read(countc, sizeof(int32_t)));
+                token_counts.emplace(token, count);
+            }
+
+            static_all_token_counts.emplace(ngram, token_counts);
+        }
+        GGML_ASSERT(hashmap_file.eof());
+
         t_draft_us += ggml_time_us() - t_start_draft_us;
     }
 
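The loader above fixes a simple binary layout for lookup.bin: per n-gram, a uint64_t key, an int32_t entry count, then that many (llama_token, int32_t) pairs, all in native endianness with no header. For reference, a writer matching that layout might look like the sketch below; `save_hashmap` is a hypothetical name, not part of this change, and it assumes the typedefs above plus `llama_token` from llama.h:

```cpp
// Sketch: serialize an all_token_hashmap in the format the loader expects.
static void save_hashmap(const all_token_hashmap & atc, const char * file_name) {
    std::ofstream out(file_name, std::ios::binary);

    for (const std::pair<const uint64_t, token_hashmap> & item : atc) {
        const uint64_t ngram   = item.first;
        const int32_t  ntokens = (int32_t) item.second.size();
        out.write(reinterpret_cast<const char *>(&ngram),   sizeof(uint64_t));
        out.write(reinterpret_cast<const char *>(&ntokens), sizeof(int32_t));

        for (const std::pair<const llama_token, int32_t> & tc : item.second) {
            const llama_token token = tc.first;
            const int32_t     count = tc.second;
            out.write(reinterpret_cast<const char *>(&token), sizeof(llama_token));
            out.write(reinterpret_cast<const char *>(&count), sizeof(int32_t));
        }
    }
}
```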
@@ -249,6 +281,20 @@ int main(int argc, char ** argv){
 
         while ((int) draft.size()-1 < n_draft) {
             bool draft_success = false;
+
+            const int static_ngram_start = inp_size-2 + draft.size()-1;
+            uint64_t static_ngram = get_token(inp, draft, static_ngram_start);
+            for (int j = static_ngram_start; j < static_ngram_start + 2; ++j) {
+                const uint64_t ngram_part = get_token(inp, draft, j);
+                static_ngram <<= 16;
+                static_ngram |= ngram_part;
+            }
+            all_token_hashmap::iterator static_token_counts_it = static_all_token_counts.find(static_ngram);
+            token_hashmap static_token_counts;
+            if (static_token_counts_it != static_all_token_counts.end()) {
+                static_token_counts = static_token_counts_it->second;
+            }
+
             for (int ngram_size = ngram_max; ngram_size >= ngram_min; --ngram_size) {
                 if (ngram_size > inp_size) {
                     continue;
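Note that this static lookup is always keyed on the last two known tokens and is computed once per drafted token, independent of the `ngram_size` loop below; the context-based hashmaps continue to cover n-gram sizes from `ngram_min` to `ngram_max`.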
@@ -271,16 +317,21 @@ int main(int argc, char ** argv){
                 const token_hashmap token_counts = token_counts_it->second;
 
                 int max_count        = 0;
+                int max_count_static = 0;
                 int sum_count        = 0;
                 llama_token max_token = -1;
 
                 for (std::pair<llama_token, int> tc : token_counts) {
                     const llama_token token = tc.first;
-                    const llama_token count = tc.second;
 
-                    if (count > max_count) {
-                        max_token = token;
-                        max_count = count;
+                    token_hashmap::iterator stc_it = static_token_counts.find(token);
+                    const int32_t count        = tc.second;
+                    const int32_t count_static = stc_it != static_token_counts.end() ? 100*stc_it->second : 1;
+
+                    if (count*count_static > max_count*max_count_static) {
+                        max_token        = token;
+                        max_count        = count;
+                        max_count_static = count_static;
                     }
                     sum_count += count;
                 }
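To make the combined score concrete: a token drawn 2 times after this n-gram in the current context with a static count of 10 scores 2 * (100*10) = 2000, while a token drawn 5 times in context but absent from the static distribution falls back to count_static = 1 and scores only 5 * 1 = 5. The static counts thus act as a strong prior for picking the draft candidate, while sum_count still accumulates only the in-context counts, so the sample-size and percentage checks that follow keep their original meaning.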
@@ -300,6 +351,38 @@ int main(int argc, char ** argv){
                     break;
                 }
 
+            if (!draft_success) {
+                int max_count = 0;
+                int sum_count = 0;
+                llama_token max_token = -1;
+
+                for (std::pair<llama_token, int> tc : static_token_counts) {
+                    const llama_token token = tc.first;
+                    const int32_t     count = tc.second;
+
+                    if (count > max_count) {
+                        max_token = token;
+                        max_count = count;
+                    }
+                    sum_count += count;
+                }
+
+                // Skip this candidate if the sample size is too low:
+                if (sum_count < draft_min_sample_size[2-1]) {
+                    break;
+                }
+                // Skip this candidate if the empirically most likely token following this n-gram is not likely enough:
+                if (100*max_count < draft_min_percent[2-1]*sum_count) {
+                    break;
+                }
+
+                LOG(" - draft candidate: token=%d count=%d\n", max_token, max_count);
+                llama_batch_add(batch_tgt, max_token, n_past + draft.size(), { 0 }, true);
+                draft.push_back(max_token);
+                draft_success = true;
+                break;
+            }
+
             if (!draft_success) {
                 break;
             }
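With the defaults above this fallback is conservative: the 2-1 index selects the 2-gram thresholds, so a token is drafted from the static distribution alone only if its 2-gram has at least draft_min_sample_size[1] = 2 observed continuations in lookup.bin and the most frequent continuation accounts for at least draft_min_percent[1] = 50 percent of them (i.e. 100*max_count >= 50*sum_count).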