Skip to content

Commit 42d7092

Browse files
JohannesGaessler authored and Nexesenex committed
lookup: single sequence -> tree of sequences
1 parent 7ce31a4 commit 42d7092

File tree

4 files changed

+308
-106
lines changed

4 files changed

+308
-106
lines changed

common/ngram-cache.cpp

Lines changed: 173 additions & 62 deletions
Original file line number | Diff line number | Diff line change
@@ -52,52 +52,101 @@ static llama_token get_token(const std::vector<llama_token> & inp, const std::ve
5252
return i < inp.size() ? inp[i] : draft[1 + i - inp.size()];
5353
}
5454

55-
// If sample size or percentage are below these thresholds the draft is aborted early:
56-
constexpr int draft_min_sample_size_lax[LLAMA_NGRAM_MAX] = { 2, 2, 1, 1};
57-
constexpr int draft_min_percent_lax[LLAMA_NGRAM_MAX] = {66, 50, 50, 50};
55+
// Sample size and percentage must meet these thresholds to be added to the draft tree:
56+
constexpr int draft_min_sample_size_lax[LLAMA_NGRAM_MAX] = { 1, 1, 1, 1};
57+
constexpr int draft_min_percent_lax[LLAMA_NGRAM_MAX] = {20, 20, 10, 10};
5858
constexpr int draft_min_sample_size_strict[LLAMA_NGRAM_MAX] = { 4, 3, 2, 2};
59-
constexpr int draft_min_percent_strict[LLAMA_NGRAM_MAX] = {75, 66, 66, 66};
59+
constexpr int draft_min_percent_strict[LLAMA_NGRAM_MAX] = {50, 50, 50, 50};
60+
61+
struct draft_candidate {
62+
llama_draft_t draft;
63+
float nll;
64+
int nsampled;
65+
};
66+
67+
struct compare_draft_candidate {
68+
bool operator()(const draft_candidate & a, const draft_candidate & b){
69+
if (a.nsampled > b.nsampled) {
70+
return true;
71+
}
72+
if (a.nsampled < b.nsampled) {
73+
return false;
74+
}
75+
return a.nll < b.nll;
76+
}
77+
};
78+
79+
// Helper function that tries to draft tokens from only the static ngram cache:
80+
static void try_draft(
81+
llama_ngram_cache & nc_static, const llama_ngram & ngram_static,
82+
const int * min_sample_size, const int * min_percent, const draft_candidate & cp,
83+
const int ngram_min, std::vector<draft_candidate> & drafts_new) {
84+
85+
const int nsc = (ngram_min + LLAMA_NGRAM_STATIC) - (cp.draft.size() - 1);
86+
if (nsc < (ngram_min + LLAMA_NGRAM_STATIC + 1)/2) {
87+
return;
88+
}
6089

61-
// Helper function that tries to draft a token from only the static ngram cache:
62-
static llama_token try_draft(llama_ngram_cache & nc_static, const llama_ngram ngram_static) {
6390
llama_ngram_cache::iterator part_static_it = nc_static.find(ngram_static);
6491
if (part_static_it == nc_static.end()) {
65-
return -1;
92+
return;
6693
}
6794
const llama_ngram_cache_part part_static = part_static_it->second;
6895

69-
int max_count_static = 0;
7096
int sum_count_static = 0;
71-
llama_token max_token = -1;
7297

7398
for (std::pair<llama_token, int> token_count_static : part_static) {
74-
const llama_token token = token_count_static.first;
7599
const int32_t count_static = token_count_static.second;
76100

77-
if (count_static > max_count_static) {
78-
max_token = token;
79-
max_count_static = count_static;
80-
}
81101
sum_count_static += count_static;
82102
}
83103

84-
if (sum_count_static < draft_min_sample_size_lax[LLAMA_NGRAM_STATIC-1]) {
85-
return -1;
86-
}
87-
if (100*max_count_static < draft_min_percent_lax[LLAMA_NGRAM_STATIC-1]*sum_count_static) {
88-
return -1;
104+
for (std::pair<llama_token, int> token_count_static : part_static) {
105+
const llama_token token = token_count_static.first;
106+
const int32_t count_static = token_count_static.second;
107+
108+
if (sum_count_static < min_sample_size[LLAMA_NGRAM_STATIC-1]) {
109+
continue;
110+
}
111+
if (100*count_static < min_percent[LLAMA_NGRAM_STATIC-1]*sum_count_static) {
112+
continue;;
113+
}
114+
115+
draft_candidate cc;
116+
for (const llama_token & t : cp.draft) {
117+
cc.draft.push_back(t);
118+
}
119+
cc.draft.push_back(token);
120+
cc.nll = cp.nll - logf(1.0f*count_static/sum_count_static);
121+
cc.nsampled = nsc;
122+
123+
bool duplicate = false;
124+
for (const draft_candidate & co : drafts_new) {
125+
if (co.draft == cc.draft) {
126+
duplicate = true;
127+
break;
128+
}
129+
}
130+
if (duplicate) {
131+
continue;
132+
}
133+
134+
drafts_new.push_back(cc);
89135
}
90-
return max_token;
91136
}
92137

93-
// Try to draft a token from primary cache (context/dynamic), validate with static cache:
94-
static llama_token try_draft(
138+
// Try to draft tokens from primary cache (context/dynamic), validate with static cache:
139+
static void try_draft(
95140
llama_ngram_cache & nc_primary, const std::vector<llama_ngram> & ngrams_primary, llama_ngram_cache_part & part_static,
96-
const int * min_sample_size, const int * min_percent) {
141+
const int * min_sample_size, const int * min_percent, const draft_candidate & cp,
142+
const int ngram_min, std::vector<draft_candidate> & drafts_new) {
97143

98-
llama_token drafted_token = -1;
144+
for (int i = ngrams_primary.size()-1; i >= 0; --i) {
145+
const int nsc = (ngram_min + i) - (cp.draft.size() - 1);
146+
if (nsc < (ngram_min + i + 1)/2) {
147+
break;
148+
}
99149

100-
for (int i = ngrams_primary.size()-1; i >= 0 && drafted_token == -1; --i) {
101150
const llama_ngram ngram_primary = ngrams_primary[i];
102151

103152
llama_ngram_cache::iterator part_primary_it = nc_primary.find(ngram_primary);
@@ -106,10 +155,8 @@ static llama_token try_draft(
106155
}
107156
const llama_ngram_cache_part part_primary = part_primary_it->second;
108157

109-
int max_count_primary = 0;
110-
int max_count_static = 0;
111158
int sum_count_primary = 0;
112-
llama_token max_token = -1;
159+
int sum_count_prod = 0;
113160

114161
for (std::pair<llama_token, int> token_count_primary : part_primary) {
115162
const llama_token token = token_count_primary.first;
@@ -119,44 +166,100 @@ static llama_token try_draft(
119166
const int32_t count_primary = token_count_primary.second;
120167
const int32_t count_static = token_count_static_it != part_static.end() ? 100*token_count_static_it->second : 1;
121168

122-
if (count_primary*count_static > max_count_primary*max_count_static) {
123-
max_token = token;
124-
max_count_primary = count_primary;
125-
max_count_static = count_static;
126-
}
127169
sum_count_primary += count_primary;
170+
sum_count_prod += count_primary*count_static;
128171
}
129172

130-
if (sum_count_primary < min_sample_size[i]) {
131-
continue;
132-
}
133-
if (100*max_count_primary < min_percent[i]*sum_count_primary) {
134-
continue;;
173+
for (std::pair<llama_token, int> token_count_primary : part_primary) {
174+
const llama_token token = token_count_primary.first;
175+
176+
llama_ngram_cache_part::iterator token_count_static_it = part_static.find(token);
177+
178+
const int32_t count_primary = token_count_primary.second;
179+
const int32_t count_static = token_count_static_it != part_static.end() ? 100*token_count_static_it->second : 1;
180+
const int32_t count_prod = count_primary*count_static;
181+
182+
if (sum_count_primary < min_sample_size[i]) {
183+
continue;
184+
}
185+
186+
if (100*count_prod < min_percent[i]*sum_count_prod) {
187+
continue;
188+
}
189+
190+
draft_candidate cc;
191+
for (const llama_token & t : cp.draft) {
192+
cc.draft.push_back(t);
193+
}
194+
cc.draft.push_back(token);
195+
cc.nll = cp.nll - logf(1.0f*count_prod/sum_count_prod);
196+
cc.nsampled = nsc;
197+
198+
bool duplicate = false;
199+
for (const draft_candidate & co : drafts_new) {
200+
if (co.draft == cc.draft) {
201+
duplicate = true;
202+
break;
203+
}
204+
}
205+
if (duplicate) {
206+
continue;
207+
}
208+
209+
drafts_new.push_back(cc);
135210
}
136-
drafted_token = max_token;
137211
}
138-
139-
return drafted_token;
140212
}
141213

142214
void llama_ngram_cache_draft(
143-
std::vector<llama_token> & inp, std::vector<llama_token> & draft, int n_draft, int ngram_min, int ngram_max,
215+
std::vector<llama_token> & inp, std::vector<std::vector<llama_token>> & drafts, int n_draft, int ngram_min, int ngram_max,
144216
llama_ngram_cache & nc_context, llama_ngram_cache & nc_dynamic, llama_ngram_cache & nc_static
145217
) {
146-
GGML_ASSERT(draft.size() == 1);
218+
if (n_draft == 0) {
219+
return;
220+
}
221+
222+
GGML_ASSERT(drafts.size() == 1);
223+
GGML_ASSERT(drafts[0].size() == 1);
147224
const int inp_size = inp.size();
148225

149-
if (inp_size < LLAMA_NGRAM_STATIC) {
226+
if (inp_size < std::max(ngram_max, LLAMA_NGRAM_STATIC)) {
150227
return;
151228
}
152229

153-
while ((int) draft.size()-1 < n_draft) {
154-
llama_token drafted_token = -1;
230+
// While building the tree, store drafts with potential children in a heap:
231+
std::vector<draft_candidate> drafts_wip;
232+
233+
{
234+
draft_candidate candidate;
235+
candidate.draft.push_back(drafts[0][0]);
236+
candidate.nll = 0.0f;
237+
candidate.nsampled = LLAMA_NGRAM_MAX;
238+
drafts_wip.push_back(candidate);
239+
}
240+
241+
drafts.clear();
242+
int i_draft = 0;
243+
244+
// Temporarily hold new drafts in vector, only add part of them in the last iteration to exactly meet n_draft.
245+
std::vector<draft_candidate> drafts_new;
155246

156-
const int ngram_start_static = inp_size-LLAMA_NGRAM_STATIC + draft.size()-1;
247+
while (i_draft + ((int) drafts_new.size()) < n_draft && !(drafts_wip.empty() && drafts_new.empty())) {
248+
for (const draft_candidate & ndc : drafts_new) {
249+
drafts_wip.push_back(ndc);
250+
std::push_heap(drafts_wip.begin(), drafts_wip.end(), compare_draft_candidate());
251+
i_draft++;
252+
}
253+
drafts_new.clear();
254+
255+
std::pop_heap(drafts_wip.begin(), drafts_wip.end(), compare_draft_candidate());
256+
const draft_candidate cp = drafts_wip.back(); // cp = candidate parent
257+
drafts_wip.pop_back();
258+
259+
const int ngram_start_static = inp_size-LLAMA_NGRAM_STATIC + cp.draft.size()-1;
157260
llama_ngram ngram_static;
158261
for (int j = ngram_start_static; j < ngram_start_static + LLAMA_NGRAM_STATIC; ++j) {
159-
ngram_static.tokens[j-ngram_start_static] = get_token(inp, draft, j);
262+
ngram_static.tokens[j-ngram_start_static] = get_token(inp, cp.draft, j);
160263
}
161264
llama_ngram_cache::iterator part_static_it = nc_static.find(ngram_static);
162265
llama_ngram_cache_part part_static;
@@ -167,29 +270,37 @@ void llama_ngram_cache_draft(
167270
// cd = context + dynamic
168271
std::vector<llama_ngram> ngrams_cd;
169272
for (int ngram_size_cd = ngram_min; ngram_size_cd <= ngram_max; ++ngram_size_cd) {
170-
const int ngram_start_cd = inp_size-ngram_size_cd + draft.size()-1;
273+
const int ngram_start_cd = inp_size-ngram_size_cd + cp.draft.size()-1;
171274
llama_ngram ngram_cd;
172275
for (int j = ngram_start_cd; j < ngram_start_cd + ngram_size_cd; ++j) {
173-
ngram_cd.tokens[j-ngram_start_cd] = get_token(inp, draft, j);
276+
ngram_cd.tokens[j-ngram_start_cd] = get_token(inp, cp.draft, j);
174277
}
175278
ngrams_cd.push_back(ngram_cd);
176279
}
177-
if (drafted_token == -1) {
178-
drafted_token = try_draft(nc_context, ngrams_cd, part_static, draft_min_sample_size_lax, draft_min_percent_lax);
179-
}
180-
if (drafted_token == -1) {
181-
drafted_token = try_draft(nc_dynamic, ngrams_cd, part_static, draft_min_sample_size_strict, draft_min_percent_strict);
182-
}
183-
if (drafted_token == -1) {
184-
drafted_token = try_draft(nc_static, ngram_static);
280+
281+
try_draft(nc_context, ngrams_cd, part_static, draft_min_sample_size_lax, draft_min_percent_lax, cp, ngram_min, drafts_new);
282+
try_draft(nc_dynamic, ngrams_cd, part_static, draft_min_sample_size_strict, draft_min_percent_lax, cp, ngram_min, drafts_new);
283+
try_draft(nc_static, ngram_static, draft_min_sample_size_strict, draft_min_percent_strict, cp, ngram_min, drafts_new);
284+
285+
if (drafts_new.empty()) {
286+
drafts.push_back(cp.draft);
287+
i_draft++;
185288
}
289+
}
186290

187-
if (drafted_token == -1) {
291+
for (const draft_candidate & dc : drafts_wip) { // dc = draft child
292+
drafts.push_back(dc.draft);
293+
}
294+
295+
std::sort(drafts_new.begin(), drafts_new.end(), compare_draft_candidate());
296+
297+
for (const draft_candidate & dc : drafts_new) {
298+
drafts.push_back(dc.draft);
299+
i_draft++;
300+
301+
if (i_draft >= n_draft) {
188302
break;
189303
}
190-
191-
LOG(" - draft candidate: token=%d\n", drafted_token);
192-
draft.push_back(drafted_token);
193304
}
194305
}
195306

common/ngram-cache.h

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -60,6 +60,7 @@ typedef std::unordered_map<llama_token, int32_t> llama_ngram_cache_part;
6060
// n-gram -> empirical distribution of following tokens
6161
typedef std::unordered_map<llama_ngram, llama_ngram_cache_part, llama_ngram_hash_function> llama_ngram_cache;
6262

63+
typedef std::vector<llama_token> llama_draft_t;
6364

6465
// Update an ngram cache with tokens.
6566
// ngram_cache: the cache to modify.
@@ -82,7 +83,7 @@ void llama_ngram_cache_update(
8283
// nc_dynamic: ngram cache based on previous user generations.
8384
// nc_static: ngram cache generated from a large text corpus, used for validation.
8485
void llama_ngram_cache_draft(
85-
std::vector<llama_token> & inp, std::vector<llama_token> & draft, int n_draft, int ngram_min, int ngram_max,
86+
std::vector<llama_token> & inp, std::vector<llama_draft_t> & drafts, int n_draft, int ngram_min, int ngram_max,
8687
llama_ngram_cache & nc_context, llama_ngram_cache & nc_dynamic, llama_ngram_cache & nc_static);
8788

8889
// Save an ngram cache to a file.

examples/lookup/lookup-stats.cpp

Lines changed: 29 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@
88
#include <cstdint>
99
#include <cstdio>
1010
#include <fstream>
11+
#include <set>
1112
#include <string>
1213
#include <vector>
1314
#include <unordered_map>
@@ -80,22 +81,42 @@ int main(int argc, char ** argv){
8081

8182
while ((int) pseudo_output.size() < n_ctx) {
8283
// Simulate drafting and decoding from draft:
83-
std::vector<llama_token> draft;
84-
draft.push_back(pseudo_output.back());
84+
std::vector<llama_draft_t> drafts;
85+
llama_draft_t draft0;
86+
draft0.push_back(pseudo_output.back());
87+
drafts.push_back(draft0);
8588

8689
{
8790
const int64_t t_start_draft_us = ggml_time_us();
88-
llama_ngram_cache_draft(pseudo_output, draft, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, ngram_cache_context, ngram_cache_dynamic, ngram_cache_static);
91+
llama_ngram_cache_draft(
92+
pseudo_output, drafts, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, ngram_cache_context, ngram_cache_dynamic, ngram_cache_static);
8993
t_draft_us += ggml_time_us() - t_start_draft_us;
9094
}
95+
GGML_ASSERT((int) drafts.size() <= n_draft || n_draft <= 0);
9196

92-
n_drafted += draft.size() - 1;
97+
// FIXME wrong KV mask for converging sequences (does not seem to happen in practice).
98+
for (int j = 1; j < n_draft + 1; ++j) {
99+
std::set<llama_token> seen_tokens;
93100

94-
for (size_t j = 1; j < draft.size() && (int) pseudo_output.size() < n_ctx; ++j) {
101+
for (const llama_draft_t & draft : drafts) {
102+
if (j < (int) draft.size() && seen_tokens.find(draft[j]) == seen_tokens.end()) {
103+
seen_tokens.emplace(draft[j]);
104+
n_drafted++;
105+
}
106+
}
107+
}
108+
109+
for (int j = 1; j < n_draft + 1 && (int) pseudo_output.size() < n_ctx; ++j) {
95110
const llama_token ground_truth = inp_slice[pseudo_output.size()];
96-
const llama_token drafted = draft[j];
97111

98-
if (ground_truth != drafted) {
112+
bool ground_truth_in_drafts = false;
113+
for (const llama_draft_t & draft : drafts) {
114+
if (j < (int) draft.size() && draft[j] == ground_truth) {
115+
ground_truth_in_drafts = true;
116+
break;
117+
}
118+
}
119+
if (!ground_truth_in_drafts) {
99120
break;
100121
}
101122

@@ -119,7 +140,7 @@ int main(int argc, char ** argv){
119140
}
120141
}
121142

122-
draft.erase(draft.begin());
143+
drafts.clear();
123144

124145
}
125146
if (i_start > 0 && i_start / 100000 != (i_start - n_ctx) / 100000) {

0 commit comments

Comments
 (0)