#include <cstdio>
#include <string>
#include <vector>
+ #include <unordered_map>
+
+ // Data structures to map n-grams to empirical token probabilities:
+ typedef std::unordered_map<llama_token, int> token_hashmap;            // token -> number of times token has been seen
+ typedef std::unordered_map<uint64_t, token_hashmap> all_token_hashmap; // n-gram -> empirical distribution of following tokens
+ // n-grams are encoded as 64 bit integers with each of the 4 16 bit sections representing a token id.
+ // This way no custom hashing function for the n-grams is needed.
+
+ // Min/max n-gram size to search for in prompt:
+ constexpr int ngram_min = 1;
+ constexpr int ngram_max = 4;
+ static_assert(ngram_max <= sizeof(uint64_t)/2, "A 64 bit integer can only hold information for 4 16 bit tokens.");
+
+ // If sample size or percentage in context are below these thresholds the draft is aborted early:
+ constexpr float draft_min_sample_size[ngram_max] = { 2,  2,  1,  1};
+ constexpr float draft_min_percent[ngram_max]     = {66, 50, 50, 50};

int main(int argc, char ** argv){
    gpt_params params;
@@ -16,9 +32,6 @@ int main(int argc, char ** argv){
    }

    // max/min n-grams size to search for in prompt
-     const int ngram_max = 4;
-     const int ngram_min = 1;
-
    // length of the candidate / draft sequence, if match is found
    const int n_draft = params.n_draft;
@@ -39,6 +52,7 @@ int main(int argc, char ** argv){

    // load the model
    std::tie(model, ctx) = llama_init_from_gpt_params(params);
+     GGML_ASSERT(llama_n_vocab(model) < (1 << 16));

    // tokenize the prompt
    const bool add_bos = llama_should_add_bos_token(model);
@@ -47,6 +61,55 @@ int main(int argc, char ** argv){
    std::vector<llama_token> inp;
    inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);

+     auto update_hashmaps = [](all_token_hashmap * atcs, const llama_token * inp_data, const int inp_size, const int nnew) -> void {
+         // atcs = all_token_counts: the hashmaps to modify.
+         // inp_data: the token sequence on which the hashmaps are based.
+         // inp_size: the current size of inp_data.
+         // nnew: how many new tokens have been appended to inp_data since the last call to this function.
+         //
+         // In order to get correct results inp_data can ONLY BE APPENDED TO.
+         // Changes in the middle need a complete rebuild.
+         for (int ngram_size = ngram_min; ngram_size <= ngram_max; ++ngram_size) {
+             all_token_hashmap * atc = atcs + ngram_size - ngram_min;
+
+             const int i_start = std::max(inp_size - nnew, ngram_size);
+             for (int i = i_start; i < inp_size; ++i) {
+                 const int ngram_start = i - ngram_size;
+                 uint64_t ngram = inp_data[ngram_start];
+                 for (int j = ngram_start; j < ngram_start + ngram_size; ++j) {
+                     const uint64_t ngram_part = inp_data[j];
+                     ngram <<= 16;
+                     ngram |= ngram_part;
+                 }
+                 const llama_token token = inp_data[i];
+
+                 all_token_hashmap::iterator token_counts_it = atc->find(ngram);
+                 if (token_counts_it == atc->end()) {
+                     token_hashmap token_counts;
+                     token_counts.emplace(token, 1);
+                     atc->emplace(ngram, token_counts);
+                 } else {
+                     token_hashmap::iterator tc_it = token_counts_it->second.find(token);
+                     if (tc_it == token_counts_it->second.end()) {
+                         token_counts_it->second.emplace(token, 1);
+                     } else {
+                         tc_it->second++;
+                     }
+                 }
+             }
+         }
+     };
+
+     all_token_hashmap all_token_counts[ngram_max-ngram_min+1];
+     int64_t t_draft_us = 0;
+
+     {
+         // Fill up hashmaps with tokens from user input:
+         const int64_t t_start_draft_us = ggml_time_us();
+         update_hashmaps(all_token_counts, inp.data(), inp.size(), inp.size());
+         t_draft_us += ggml_time_us() - t_start_draft_us;
+     }
+
    const int max_context_size = llama_n_ctx(ctx);
    const int max_tokens_list_size = max_context_size - 4;
@@ -76,8 +139,6 @@ int main(int argc, char ** argv){
    int n_drafted = 0;
    int n_accept = 0;

-     int64_t t_draft_us = 0;
-
    int n_past = inp.size();

    bool has_eos = false;
@@ -129,6 +190,12 @@ int main(int argc, char ** argv){
                ++n_past;
                ++i_dft;
                inp.push_back(id);
+                 {
+                     // Update hashmaps with the newly accepted token:
+                     const int64_t t_start_draft_us = ggml_time_us();
+                     update_hashmaps(all_token_counts, inp.data(), inp.size(), 1);
+                     t_draft_us += ggml_time_us() - t_start_draft_us;
+                 }

                if (params.use_color) {
                    // color accepted draft token
@@ -149,6 +216,12 @@ int main(int argc, char ** argv){
            draft.clear();
            draft.push_back(id);
            inp.push_back(id);
+             {
+                 // Update hashmaps with the newly accepted token:
+                 const int64_t t_start_draft_us = ggml_time_us();
+                 update_hashmaps(all_token_counts, inp.data(), inp.size(), 1);
+                 t_draft_us += ggml_time_us() - t_start_draft_us;
+             }
            break;
        }

@@ -163,44 +236,85 @@ int main(int argc, char ** argv){
        llama_batch_clear(batch_tgt);
        llama_batch_add(batch_tgt, draft[0], n_past, { 0 }, true);

-         // generate n_pred tokens through prompt lookup
+         auto get_token = [](const std::vector<llama_token> inp, const std::vector<llama_token> draft, const size_t i) -> llama_token {
+             // Helper function to get a token from the combined, speculative sequence of inp and draft.
+             return i < inp.size() ? inp[i] : draft[1 + i - inp.size()];
+         };
+
        auto prompt_lookup = [&]() -> void {
+             // Generate up to n_draft additional tokens through prompt lookup.
+             // The draft is aborted early if there is no suitable token candidate to continue the draft.
+             // At the beginning of this function the draft already contains a single token sampled from the model.
            const int inp_size = inp.size();
-             for (int ngram_size = ngram_max; ngram_size > ngram_min; --ngram_size){
-                 const llama_token * ngram = &inp[inp_size - ngram_size];
-
-                 for (int i = 0; i <= (int) inp_size - (ngram_size * 2); ++i) {
-                     bool match = true;
-                     for (int j = 0; j < ngram_size; ++j) {
-                         if (inp[i + j] != ngram[j]) {
-                             match = false;
-                             break;
-                         }
+
+             while ((int) draft.size()-1 < n_draft) {
+                 bool draft_success = false;
+                 for (int ngram_size = ngram_max; ngram_size >= ngram_min; --ngram_size) {
+                     if (ngram_size > inp_size) {
+                         continue;
+                     }
+
+                     all_token_hashmap & atc = all_token_counts[ngram_size - ngram_min];
+
+                     const int ngram_start = inp_size-ngram_size + draft.size()-1;
+                     uint64_t ngram = get_token(inp, draft, ngram_start);
+                     for (int j = ngram_start; j < ngram_start + ngram_size; ++j) {
+                         const uint64_t ngram_part = get_token(inp, draft, j);
+                         ngram <<= 16;
+                         ngram |= ngram_part;
                    }

-                     if (match) {
-                         const int startIdx = i + ngram_size;
-                         const int endIdx = startIdx + n_draft;
-                         if (endIdx < inp_size) {
-                             for (int j = startIdx; j < endIdx; ++j) {
-                                 LOG(" - draft candidate %d: %d\n", j, inp[j]);
-                                 draft.push_back(inp[j]);
-                                 llama_batch_add(batch_tgt, inp[j], n_past + (j - startIdx) + 1, { 0 }, true);
-                                 ++n_drafted;
-                             }
-                             return;
+                     all_token_hashmap::iterator token_counts_it = atc.find(ngram);
+                     if (token_counts_it == atc.end()) {
+                         continue;
+                     }
+                     const token_hashmap token_counts = token_counts_it->second;
+
+                     int max_count = 0;
+                     int sum_count = 0;
+                     llama_token max_token = -1;
+
+                     for (std::pair<llama_token, int> tc : token_counts) {
+                         const llama_token token = tc.first;
+                         const llama_token count = tc.second;
+
+                         if (count > max_count) {
+                             max_token = token;
+                             max_count = count;
                        }
+                         sum_count += count;
+                     }
+                     // Skip this candidate if the sample size is too low:
+                     if (sum_count < draft_min_sample_size[ngram_size-1]) {
+                         continue;
                    }
+                     // skip this candidate if the empirically most likely token following this token is not likely enough:
+                     if (100*max_count < draft_min_percent[ngram_size-1]*sum_count) {
+                         continue;
+                     }
+
+                     LOG(" - draft candidate: token=%d count=%d\n", max_token, max_count);
+                     llama_batch_add(batch_tgt, max_token, n_past + draft.size(), { 0 }, true);
+                     draft.push_back(max_token);
+                     draft_success = true;
+                     break;
+                 }
+
+                 if (!draft_success) {
+                     break;
                }
            }
-             return;
        };

+         // Draft already contains a single token sampled from the model:
+         GGML_ASSERT(draft.size() == 1);
+         GGML_ASSERT(draft[0] == inp.back());
        const int64_t t_start_draft_us = ggml_time_us();

        prompt_lookup();

        t_draft_us += ggml_time_us() - t_start_draft_us;
+         n_drafted += draft.size() - 1;

        llama_decode(ctx, batch_tgt);
        ++n_past;
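The n-gram keys built above are plain 64 bit integers, which is why std::unordered_map needs no custom hash for them. Below is a small standalone sketch of that encoding; the encode_ngram helper and the example token ids are hypothetical and not part of the change above, and the sketch assumes token ids fit into 16 bits, as the added GGML_ASSERT on llama_n_vocab enforces.

// Standalone sketch (not part of the diff above) of the n-gram key encoding:
// each token id occupies one 16 bit slice of a uint64_t, so the default
// std::unordered_map hash for integers can be used directly.
#include <cstdint>
#include <cstdio>

// Mirrors the shift/or loop in update_hashmaps/prompt_lookup: the key is seeded
// with the first token and then all ngram_size tokens are shifted in, 16 bits at a time.
static uint64_t encode_ngram(const int32_t * tokens, int ngram_start, int ngram_size) {
    uint64_t ngram = (uint16_t) tokens[ngram_start];
    for (int j = ngram_start; j < ngram_start + ngram_size; ++j) {
        ngram <<= 16;
        ngram |= (uint16_t) tokens[j];
    }
    return ngram;
}

int main() {
    const int32_t tokens[] = {11, 22, 33, 44};
    // For a full 4-gram the seed is shifted out again and the key is simply
    // 0x000b00160021002c, i.e. 11, 22, 33, 44 in consecutive 16 bit slices:
    printf("%016llx\n", (unsigned long long) encode_ngram(tokens, 0, 4));
    return 0;
}

For shorter n-grams the seed token also remains in the upper bits of the key, but since update_hashmaps and prompt_lookup build the keys with the same loop, hashmap lookups stay consistent either way.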