#include <cstdio>
#include <string>
#include <vector>
+ #include <unordered_map>
+
+ // Data structures to map n-grams to empirical token probabilities:
+ typedef std::unordered_map<llama_token, int> token_hashmap; // token -> number of times token has been seen
+ typedef std::unordered_map<uint64_t, token_hashmap> all_token_hashmap; // n-gram -> empirical distribution of following tokens
+ // n-grams are encoded as 64 bit integers with each of the 4 16 bit sections representing a token id.
+ // This way no custom hashing function for the n-grams is needed.
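+ // (Appending a token t to a key is simply: ngram = (ngram << 16) | t.)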
+
+ // Min/max n-gram size to search for in prompt:
+ constexpr int ngram_min = 1;
+ constexpr int ngram_max = 4;
+ static_assert(ngram_max <= sizeof(uint64_t)/2, "A 64 bit integer can only hold information for 4 16 bit tokens.");
+
+ // If the sample size or the percentage of the most frequent continuation are below these thresholds the draft is aborted early:
+ constexpr float draft_min_sample_size[ngram_max] = { 2, 2, 1, 1};
+ constexpr float draft_min_percent[ngram_max] = {66, 50, 50, 50};
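+ // e.g. a 1-gram is only used for drafting if it has been seen at least 2 times and its most frequent continuation covers at least 66% of those occurrences.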

int main(int argc, char ** argv){
    gpt_params params;
@@ -16,9 +32,6 @@ int main(int argc, char ** argv){
    }

    // max/min n-grams size to search for in prompt
-     const int ngram_max = 4;
-     const int ngram_min = 1;
-
    // length of the candidate / draft sequence, if match is found
    const int n_draft = params.n_draft;

@@ -38,6 +51,7 @@ int main(int argc, char ** argv){

    // load the model
    std::tie(model, ctx) = llama_init_from_gpt_params(params);
+     GGML_ASSERT(llama_n_vocab(model) < (1 << 16));
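+     // Token ids must fit into 16 bits so that up to 4 of them can be packed into a 64 bit n-gram key.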

    // tokenize the prompt
    const bool add_bos = llama_should_add_bos_token(model);
@@ -46,6 +60,55 @@ int main(int argc, char ** argv){
    std::vector<llama_token> inp;
    inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);

+     auto update_hashmaps = [](all_token_hashmap * atcs, const llama_token * inp_data, const int inp_size, const int nnew) -> void {
+         // atcs = all_token_counts: the hashmaps to modify.
+         // inp_data: the token sequence on which the hashmaps are based.
+         // inp_size: the current size of inp_data.
+         // nnew: how many new tokens have been appended to inp_data since the last call to this function.
+         //
+         // In order to get correct results inp_data can ONLY BE APPENDED TO.
+         // Changes in the middle need a complete rebuild.
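+         // For each n-gram size, count each newly appended token as a continuation of the n-gram that immediately precedes it.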
+         for (int ngram_size = ngram_min; ngram_size <= ngram_max; ++ngram_size) {
+             all_token_hashmap * atc = atcs + ngram_size - ngram_min;
+
+             const int i_start = std::max(inp_size - nnew, ngram_size);
+             for (int i = i_start; i < inp_size; ++i) {
+                 const int ngram_start = i - ngram_size;
+                 uint64_t ngram = inp_data[ngram_start];
+                 for (int j = ngram_start; j < ngram_start + ngram_size; ++j) {
+                     const uint64_t ngram_part = inp_data[j];
+                     ngram <<= 16;
+                     ngram |= ngram_part;
+                 }
+                 const llama_token token = inp_data[i];
+
+                 all_token_hashmap::iterator token_counts_it = atc->find(ngram);
+                 if (token_counts_it == atc->end()) {
+                     token_hashmap token_counts;
+                     token_counts.emplace(token, 1);
+                     atc->emplace(ngram, token_counts);
+                 } else {
+                     token_hashmap::iterator tc_it = token_counts_it->second.find(token);
+                     if (tc_it == token_counts_it->second.end()) {
+                         token_counts_it->second.emplace(token, 1);
+                     } else {
+                         tc_it->second++;
+                     }
+                 }
+             }
+         }
+     };
+
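+     // One hashmap per n-gram size in [ngram_min, ngram_max]: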
+     all_token_hashmap all_token_counts[ngram_max-ngram_min+1];
+     int64_t t_draft_us = 0;
+
+     {
+         // Fill up hashmaps with tokens from user input:
+         const int64_t t_start_draft_us = ggml_time_us();
+         update_hashmaps(all_token_counts, inp.data(), inp.size(), inp.size());
+         t_draft_us += ggml_time_us() - t_start_draft_us;
+     }
+

    const int max_context_size = llama_n_ctx(ctx);
    const int max_tokens_list_size = max_context_size - 4;

@@ -75,8 +138,6 @@ int main(int argc, char ** argv){
    int n_drafted = 0;
    int n_accept = 0;

-     int64_t t_draft_us = 0;
-
    int n_past = inp.size();

    bool has_eos = false;
@@ -128,6 +189,12 @@ int main(int argc, char ** argv){
                ++n_past;
                ++i_dft;
                inp.push_back(id);
+                 {
+                     // Update hashmaps with the newly accepted token:
+                     const int64_t t_start_draft_us = ggml_time_us();
+                     update_hashmaps(all_token_counts, inp.data(), inp.size(), 1);
+                     t_draft_us += ggml_time_us() - t_start_draft_us;
+                 }

                if (params.use_color) {
                    // color accepted draft token
@@ -148,6 +215,12 @@ int main(int argc, char ** argv){
            draft.clear();
            draft.push_back(id);
            inp.push_back(id);
+             {
+                 // Update hashmaps with the newly accepted token:
+                 const int64_t t_start_draft_us = ggml_time_us();
+                 update_hashmaps(all_token_counts, inp.data(), inp.size(), 1);
+                 t_draft_us += ggml_time_us() - t_start_draft_us;
+             }
            break;
        }

@@ -162,44 +235,85 @@ int main(int argc, char ** argv){
        llama_batch_clear(batch_tgt);
        llama_batch_add(batch_tgt, draft[0], n_past, { 0 }, true);

-         // generate n_pred tokens through prompt lookup
+         auto get_token = [](const std::vector<llama_token> inp, const std::vector<llama_token> draft, const size_t i) -> llama_token {
+             // Helper function to get a token from the combined, speculative sequence of inp and draft.
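+             // (draft[0] is the same token as inp.back(), hence the offset of 1 below.)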
+             return i < inp.size() ? inp[i] : draft[1 + i - inp.size()];
+         };
+
        auto prompt_lookup = [&]() -> void {
+             // Generate up to n_draft additional tokens through prompt lookup.
+             // The draft is aborted early if there is no suitable token candidate to continue the draft.
+             // At the beginning of this function the draft already contains a single token sampled from the model.
            const int inp_size = inp.size();
-             for (int ngram_size = ngram_max; ngram_size > ngram_min; --ngram_size){
-                 const llama_token * ngram = &inp[inp_size - ngram_size];
-
-                 for (int i = 0; i <= (int) inp_size - (ngram_size * 2); ++i) {
-                     bool match = true;
-                     for (int j = 0; j < ngram_size; ++j) {
-                         if (inp[i + j] != ngram[j]) {
-                             match = false;
-                             break;
-                         }
+
+             while ((int) draft.size()-1 < n_draft) {
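+                 // Try n-gram sizes from largest to smallest and use the first one with a sufficiently confident continuation.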
+                 bool draft_success = false;
+                 for (int ngram_size = ngram_max; ngram_size >= ngram_min; --ngram_size) {
+                     if (ngram_size > inp_size) {
+                         continue;
+                     }
+
+                     all_token_hashmap & atc = all_token_counts[ngram_size - ngram_min];
+
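+                     // The n-gram to look up ends at the last token of the combined, speculative sequence of inp and draft: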
+                     const int ngram_start = inp_size-ngram_size + draft.size()-1;
+                     uint64_t ngram = get_token(inp, draft, ngram_start);
+                     for (int j = ngram_start; j < ngram_start + ngram_size; ++j) {
+                         const uint64_t ngram_part = get_token(inp, draft, j);
+                         ngram <<= 16;
+                         ngram |= ngram_part;
                    }

-                     if (match) {
-                         const int startIdx = i + ngram_size;
-                         const int endIdx = startIdx + n_draft;
-                         if (endIdx < inp_size) {
-                             for (int j = startIdx; j < endIdx; ++j) {
-                                 LOG(" - draft candidate %d: %d\n", j, inp[j]);
-                                 draft.push_back(inp[j]);
-                                 llama_batch_add(batch_tgt, inp[j], n_past + (j - startIdx) + 1, { 0 }, true);
-                                 ++n_drafted;
-                             }
-                             return;
+                     all_token_hashmap::iterator token_counts_it = atc.find(ngram);
+                     if (token_counts_it == atc.end()) {
+                         continue;
+                     }
+                     const token_hashmap token_counts = token_counts_it->second;
+
+                     int max_count = 0;
+                     int sum_count = 0;
+                     llama_token max_token = -1;
+
+                     for (std::pair<llama_token, int> tc : token_counts) {
+                         const llama_token token = tc.first;
+                         const llama_token count = tc.second;
+
+                         if (count > max_count) {
+                             max_token = token;
+                             max_count = count;
                        }
+                         sum_count += count;
+                     }
+                     // Skip this candidate if the sample size is too low:
+                     if (sum_count < draft_min_sample_size[ngram_size-1]) {
+                         continue;
                    }
+                     // Skip this candidate if the empirically most likely token following this n-gram is not likely enough:
+                     if (100*max_count < draft_min_percent[ngram_size-1]*sum_count) {
+                         continue;
+                     }
+
+                     LOG(" - draft candidate: token=%d count=%d\n", max_token, max_count);
+                     llama_batch_add(batch_tgt, max_token, n_past + draft.size(), { 0 }, true);
+                     draft.push_back(max_token);
+                     draft_success = true;
+                     break;
+                 }
+
+                 if (!draft_success) {
+                     break;
                }
            }
-             return;
        };

+         // Draft already contains a single token sampled from the model:
+         GGML_ASSERT(draft.size() == 1);
+         GGML_ASSERT(draft[0] == inp.back());
        const int64_t t_start_draft_us = ggml_time_us();

        prompt_lookup();

        t_draft_us += ggml_time_us() - t_start_draft_us;
+         n_drafted += draft.size() - 1;

        llama_decode(ctx, batch_tgt);
        ++n_past;