@@ -52,52 +52,101 @@ static llama_token get_token(const std::vector<llama_token> & inp, const std::ve
     return i < inp.size() ? inp[i] : draft[1 + i - inp.size()];
 }

-// If sample size or percentage are below these thresholds the draft is aborted early:
-constexpr int draft_min_sample_size_lax[LLAMA_NGRAM_MAX] = { 2,  2,  1,  1};
-constexpr int draft_min_percent_lax[LLAMA_NGRAM_MAX]     = {66, 50, 50, 50};
+// Sample size and percentage must meet these thresholds to be added to the draft tree:
+constexpr int draft_min_sample_size_lax[LLAMA_NGRAM_MAX] = { 1,  1,  1,  1};
+constexpr int draft_min_percent_lax[LLAMA_NGRAM_MAX]     = {20, 20, 10, 10};

 constexpr int draft_min_sample_size_strict[LLAMA_NGRAM_MAX] = { 4,  3,  2,  2};
-constexpr int draft_min_percent_strict[LLAMA_NGRAM_MAX]     = {75, 66, 66, 66};
+constexpr int draft_min_percent_strict[LLAMA_NGRAM_MAX]     = {50, 50, 50, 50};
+
+struct draft_candidate {
+    llama_draft_t draft;
+    float nll;
+    int nsampled;
+};
+
+struct compare_draft_candidate {
+    bool operator()(const draft_candidate & a, const draft_candidate & b){
+        if (a.nsampled > b.nsampled) {
+            return true;
+        }
+        if (a.nsampled < b.nsampled) {
+            return false;
+        }
+        return a.nll < b.nll;
+    }
+};
+
+// Helper function that tries to draft tokens from only the static ngram cache:
+static void try_draft(
+    llama_ngram_cache & nc_static, const llama_ngram & ngram_static,
+    const int * min_sample_size, const int * min_percent, const draft_candidate & cp,
+    const int ngram_min, std::vector<draft_candidate> & drafts_new) {
+
+    const int nsc = (ngram_min + LLAMA_NGRAM_STATIC) - (cp.draft.size() - 1);
+    if (nsc < (ngram_min + LLAMA_NGRAM_STATIC + 1)/2) {
+        return;
+    }

-// Helper function that tries to draft a token from only the static ngram cache:
-static llama_token try_draft(llama_ngram_cache & nc_static, const llama_ngram ngram_static) {
     llama_ngram_cache::iterator part_static_it = nc_static.find(ngram_static);
     if (part_static_it == nc_static.end()) {
-        return -1;
+        return;
     }
     const llama_ngram_cache_part part_static = part_static_it->second;

-    int max_count_static = 0;
     int sum_count_static = 0;
-    llama_token max_token = -1;

     for (std::pair<llama_token, int> token_count_static : part_static) {
-        const llama_token token = token_count_static.first;
         const int32_t count_static = token_count_static.second;

-        if (count_static > max_count_static) {
-            max_token = token;
-            max_count_static = count_static;
-        }
         sum_count_static += count_static;
     }

-    if (sum_count_static < draft_min_sample_size_lax[LLAMA_NGRAM_STATIC-1]) {
-        return -1;
-    }
-    if (100*max_count_static < draft_min_percent_lax[LLAMA_NGRAM_STATIC-1]*sum_count_static) {
-        return -1;
+    for (std::pair<llama_token, int> token_count_static : part_static) {
+        const llama_token token = token_count_static.first;
+        const int32_t count_static = token_count_static.second;
+
+        if (sum_count_static < min_sample_size[LLAMA_NGRAM_STATIC-1]) {
+            continue;
+        }
+        if (100*count_static < min_percent[LLAMA_NGRAM_STATIC-1]*sum_count_static) {
+            continue;
+        }
+
+        draft_candidate cc;
+        for (const llama_token & t : cp.draft) {
+            cc.draft.push_back(t);
+        }
+        cc.draft.push_back(token);
+        cc.nll = cp.nll - logf(1.0f*count_static/sum_count_static);
+        cc.nsampled = nsc;
+
+        bool duplicate = false;
+        for (const draft_candidate & co : drafts_new) {
+            if (co.draft == cc.draft) {
+                duplicate = true;
+                break;
+            }
+        }
+        if (duplicate) {
+            continue;
+        }
+
+        drafts_new.push_back(cc);
     }
-    return max_token;
 }

-// Try to draft a token from primary cache (context/dynamic), validate with static cache:
-static llama_token try_draft(
+// Try to draft tokens from primary cache (context/dynamic), validate with static cache:
+static void try_draft(
     llama_ngram_cache & nc_primary, const std::vector<llama_ngram> & ngrams_primary, llama_ngram_cache_part & part_static,
-    const int * min_sample_size, const int * min_percent) {
+    const int * min_sample_size, const int * min_percent, const draft_candidate & cp,
+    const int ngram_min, std::vector<draft_candidate> & drafts_new) {

-    llama_token drafted_token = -1;
+    for (int i = ngrams_primary.size()-1; i >= 0; --i) {
+        const int nsc = (ngram_min + i) - (cp.draft.size() - 1);
+        if (nsc < (ngram_min + i + 1)/2) {
+            break;
+        }

-    for (int i = ngrams_primary.size()-1; i >= 0 && drafted_token == -1; --i) {
         const llama_ngram ngram_primary = ngrams_primary[i];

         llama_ngram_cache::iterator part_primary_it = nc_primary.find(ngram_primary);
@@ -106,10 +155,8 @@ static llama_token try_draft(
         }
         const llama_ngram_cache_part part_primary = part_primary_it->second;

-        int max_count_primary = 0;
-        int max_count_static = 0;
         int sum_count_primary = 0;
-        llama_token max_token = -1;
+        int sum_count_prod = 0;

         for (std::pair<llama_token, int> token_count_primary : part_primary) {
             const llama_token token = token_count_primary.first;
@@ -119,44 +166,100 @@ static llama_token try_draft(
             const int32_t count_primary = token_count_primary.second;
             const int32_t count_static = token_count_static_it != part_static.end() ? 100*token_count_static_it->second : 1;

-            if (count_primary*count_static > max_count_primary*max_count_static) {
-                max_token = token;
-                max_count_primary = count_primary;
-                max_count_static = count_static;
-            }
             sum_count_primary += count_primary;
+            sum_count_prod += count_primary*count_static;
         }

-        if (sum_count_primary < min_sample_size[i]) {
-            continue;
-        }
-        if (100*max_count_primary < min_percent[i]*sum_count_primary) {
-            continue;;
+        for (std::pair<llama_token, int> token_count_primary : part_primary) {
+            const llama_token token = token_count_primary.first;
+
+            llama_ngram_cache_part::iterator token_count_static_it = part_static.find(token);
+
+            const int32_t count_primary = token_count_primary.second;
+            const int32_t count_static = token_count_static_it != part_static.end() ? 100*token_count_static_it->second : 1;
+            const int32_t count_prod = count_primary*count_static;
+
+            if (sum_count_primary < min_sample_size[i]) {
+                continue;
+            }
+
+            if (100*count_prod < min_percent[i]*sum_count_prod) {
+                continue;
+            }
+
+            draft_candidate cc;
+            for (const llama_token & t : cp.draft) {
+                cc.draft.push_back(t);
+            }
+            cc.draft.push_back(token);
+            cc.nll = cp.nll - logf(1.0f*count_prod/sum_count_prod);
+            cc.nsampled = nsc;
+
+            bool duplicate = false;
+            for (const draft_candidate & co : drafts_new) {
+                if (co.draft == cc.draft) {
+                    duplicate = true;
+                    break;
+                }
+            }
+            if (duplicate) {
+                continue;
+            }
+
+            drafts_new.push_back(cc);
         }
-        drafted_token = max_token;
     }
-
-    return drafted_token;
 }

 void llama_ngram_cache_draft(
-    std::vector<llama_token> & inp, std::vector<llama_token> & draft, int n_draft, int ngram_min, int ngram_max,
+    std::vector<llama_token> & inp, std::vector<std::vector<llama_token>> & drafts, int n_draft, int ngram_min, int ngram_max,
     llama_ngram_cache & nc_context, llama_ngram_cache & nc_dynamic, llama_ngram_cache & nc_static
 ) {
-    GGML_ASSERT(draft.size() == 1);
+    if (n_draft == 0) {
+        return;
+    }
+
+    GGML_ASSERT(drafts.size() == 1);
+    GGML_ASSERT(drafts[0].size() == 1);
     const int inp_size = inp.size();

-    if (inp_size < LLAMA_NGRAM_STATIC) {
+    if (inp_size < std::max(ngram_max, LLAMA_NGRAM_STATIC)) {
         return;
     }

-    while ((int) draft.size()-1 < n_draft) {
-        llama_token drafted_token = -1;
+    // While building the tree, store drafts with potential children in a heap:
+    std::vector<draft_candidate> drafts_wip;
+
+    {
+        draft_candidate candidate;
+        candidate.draft.push_back(drafts[0][0]);
+        candidate.nll = 0.0f;
+        candidate.nsampled = LLAMA_NGRAM_MAX;
+        drafts_wip.push_back(candidate);
+    }
+
+    drafts.clear();
+    int i_draft = 0;
+
+    // Temporarily hold new drafts in a vector, only add part of them in the last iteration to exactly meet n_draft.
+    std::vector<draft_candidate> drafts_new;

-        const int ngram_start_static = inp_size-LLAMA_NGRAM_STATIC + draft.size()-1;
+    while (i_draft + ((int) drafts_new.size()) < n_draft && !(drafts_wip.empty() && drafts_new.empty())) {
+        for (const draft_candidate & ndc : drafts_new) {
+            drafts_wip.push_back(ndc);
+            std::push_heap(drafts_wip.begin(), drafts_wip.end(), compare_draft_candidate());
+            i_draft++;
+        }
+        drafts_new.clear();
+
+        std::pop_heap(drafts_wip.begin(), drafts_wip.end(), compare_draft_candidate());
+        const draft_candidate cp = drafts_wip.back(); // cp = candidate parent
+        drafts_wip.pop_back();
+
+        const int ngram_start_static = inp_size-LLAMA_NGRAM_STATIC + cp.draft.size()-1;
         llama_ngram ngram_static;
         for (int j = ngram_start_static; j < ngram_start_static + LLAMA_NGRAM_STATIC; ++j) {
-            ngram_static.tokens[j-ngram_start_static] = get_token(inp, draft, j);
+            ngram_static.tokens[j-ngram_start_static] = get_token(inp, cp.draft, j);
         }
         llama_ngram_cache::iterator part_static_it = nc_static.find(ngram_static);
         llama_ngram_cache_part part_static;
@@ -167,29 +270,37 @@ void llama_ngram_cache_draft(
         // cd = context + dynamic
         std::vector<llama_ngram> ngrams_cd;
         for (int ngram_size_cd = ngram_min; ngram_size_cd <= ngram_max; ++ngram_size_cd) {
-            const int ngram_start_cd = inp_size-ngram_size_cd + draft.size()-1;
+            const int ngram_start_cd = inp_size-ngram_size_cd + cp.draft.size()-1;
             llama_ngram ngram_cd;
             for (int j = ngram_start_cd; j < ngram_start_cd + ngram_size_cd; ++j) {
-                ngram_cd.tokens[j-ngram_start_cd] = get_token(inp, draft, j);
+                ngram_cd.tokens[j-ngram_start_cd] = get_token(inp, cp.draft, j);
             }
             ngrams_cd.push_back(ngram_cd);
         }
-        if (drafted_token == -1) {
-            drafted_token = try_draft(nc_context, ngrams_cd, part_static, draft_min_sample_size_lax, draft_min_percent_lax);
-        }
-        if (drafted_token == -1) {
-            drafted_token = try_draft(nc_dynamic, ngrams_cd, part_static, draft_min_sample_size_strict, draft_min_percent_strict);
-        }
-        if (drafted_token == -1) {
-            drafted_token = try_draft(nc_static, ngram_static);
+
+        try_draft(nc_context, ngrams_cd, part_static, draft_min_sample_size_lax, draft_min_percent_lax, cp, ngram_min, drafts_new);
+        try_draft(nc_dynamic, ngrams_cd, part_static, draft_min_sample_size_strict, draft_min_percent_lax, cp, ngram_min, drafts_new);
+        try_draft(nc_static, ngram_static, draft_min_sample_size_strict, draft_min_percent_strict, cp, ngram_min, drafts_new);
+
+        if (drafts_new.empty()) {
+            drafts.push_back(cp.draft);
+            i_draft++;
         }
+    }

-        if (drafted_token == -1) {
+    for (const draft_candidate & dc : drafts_wip) { // dc = draft child
+        drafts.push_back(dc.draft);
+    }
+
+    std::sort(drafts_new.begin(), drafts_new.end(), compare_draft_candidate());
+
+    for (const draft_candidate & dc : drafts_new) {
+        drafts.push_back(dc.draft);
+        i_draft++;
+
+        if (i_draft >= n_draft) {
             break;
         }
-
-        LOG(" - draft candidate: token=%d\n", drafted_token);
-        draft.push_back(drafted_token);
     }
 }

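Side note, not part of the diff above: with this change llama_ngram_cache_draft fills a vector of token sequences, one per root-to-leaf branch of the draft tree, instead of extending a single flat draft. Below is a minimal caller-side sketch of how the new interface might be driven, assuming the declarations live in common/ngram-cache.h and that the three caches were already populated (e.g. with llama_ngram_cache_update); the helper and variable names (build_draft_tree, inp_history, id_last) are illustrative, not part of the codebase.

#include "ngram-cache.h"

#include <vector>

// Hypothetical helper: build a draft tree rooted at the last sampled token.
// llama_ngram_cache_draft() asserts that exactly one branch containing exactly
// that one token is passed in.
static std::vector<std::vector<llama_token>> build_draft_tree(
        std::vector<llama_token> & inp_history, llama_token id_last, int n_draft,
        int ngram_min, int ngram_max,
        llama_ngram_cache & nc_context, llama_ngram_cache & nc_dynamic, llama_ngram_cache & nc_static) {
    std::vector<std::vector<llama_token>> drafts = {{id_last}};

    llama_ngram_cache_draft(inp_history, drafts, n_draft, ngram_min, ngram_max,
                            nc_context, nc_dynamic, nc_static);

    // On the normal path the callee clears the vector and pushes one entry per
    // branch; every returned branch still starts with id_last.
    return drafts;
}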
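One more note, also not part of the diff: compare_draft_candidate serves both as the comparator for the drafts_wip heap and for the final std::sort of drafts_new. Read as a less-than, it ranks one candidate ahead of another if it has a larger nsampled, or an equal nsampled and a lower nll; nll accumulates the negative log of each drafted token's share of the observed continuation counts, so a lower value means a more probable branch (two steps that each cover half of the observed continuations give nll ≈ 1.39, i.e. a branch probability of 0.25). The standalone toy program below, with trimmed-down stand-ins for the structs from the diff, just demonstrates the resulting sort order that the final loop relies on when truncating drafts_new to exactly meet n_draft.

#include <algorithm>
#include <cstdio>
#include <vector>

// Trimmed-down stand-ins for the structs added by this diff:
struct draft_candidate {
    std::vector<int> draft;
    float nll;
    int nsampled;
};

struct compare_draft_candidate {
    bool operator()(const draft_candidate & a, const draft_candidate & b){
        if (a.nsampled > b.nsampled) {
            return true;
        }
        if (a.nsampled < b.nsampled) {
            return false;
        }
        return a.nll < b.nll;
    }
};

int main() {
    std::vector<draft_candidate> cands = {
        {{1}, 0.9f, 2}, {{2}, 0.5f, 4}, {{3}, 0.2f, 2}, {{4}, 1.0f, 5},
    };

    // Orders by nsampled descending, then nll ascending:
    std::sort(cands.begin(), cands.end(), compare_draft_candidate());

    for (const draft_candidate & c : cands) {
        std::printf("nsampled=%d nll=%.2f\n", c.nsampled, c.nll);
    }
    // Expected order: (5, 1.00), (4, 0.50), (2, 0.20), (2, 0.90)
    return 0;
}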