
Commit 0d4d0c1

speculative : simplify (cont)
ggml-ci
1 parent e4c122b commit 0d4d0c1

File tree

common/sampling.cpp
common/sampling.h
common/speculative.cpp
common/speculative.h
examples/speculative-simple/speculative-simple.cpp

5 files changed: +56 -67 lines changed

common/sampling.cpp

Lines changed: 5 additions & 18 deletions

@@ -320,7 +320,7 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
     return cur_p.data[cur_p.selected].id;
 }
 
-std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const std::vector<llama_token> & draft, bool grammar_first) {
+std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) {
     GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");
 
     std::vector<llama_token> result;
@@ -342,23 +342,10 @@ std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl,
     return result;
 }
 
-std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const struct llama_batch & batch, bool grammar_first) {
-    std::vector<int> idxs;
-    idxs.reserve(batch.n_tokens);
-
-    std::vector<llama_token> draft;
-    draft.reserve(batch.n_tokens);
-
-    for (int i = 0; i < batch.n_tokens; i++) {
-        if (batch.logits[i] == 0) {
-            continue;
-        }
-
-        if (idxs.size() > 0) {
-            GGML_ASSERT(batch.pos[idxs.back()] + 1 == batch.pos[i]);
-            draft.push_back(batch.token[i]);
-        }
-        idxs.push_back(i);
+std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) {
+    std::vector<int> idxs(draft.size() + 1);
+    for (size_t i = 0; i < idxs.size(); ++i) {
+        idxs[i] = i;
     }
 
     return common_sampler_sample_n(gsmpl, ctx, idxs, draft, grammar_first);

common/sampling.h

Lines changed: 3 additions & 2 deletions

@@ -71,9 +71,10 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
 //
 // returns at least 1 token, up to idxs.size()
 //
-std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const std::vector<llama_token> & draft, bool grammar_first = false);
+std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first = false);
 
-std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const struct llama_batch & batch, bool grammar_first = false);
+// assume idxs == [ 0, 1, 2, ..., draft.size() ]
+std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false);
 
 uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);
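
The draft-only overload is equivalent to calling the idxs-based overload with idxs == [ 0, 1, 2, ..., draft.size() ], i.e. it assumes the target batch was laid out as [ id_last, draft0, ..., draftN-1 ] with logits requested for every position. A minimal caller-side sketch of the intended use (the helper name verify_draft_sketch and the surrounding plumbing are illustrative, not part of this commit):

    // sketch: verify a drafted continuation against the target model.
    // assumes batch_tgt already holds [ id_last, draft0, ..., draftN-1 ] with logits enabled
    static int verify_draft_sketch(
            struct common_sampler * smpl,
            struct llama_context  * ctx_tgt,
            const llama_batch     & batch_tgt,
            const llama_tokens    & draft) {
        llama_decode(ctx_tgt, batch_tgt);

        // ids contains every draft token the target sampler agreed with, plus one
        // final token sampled by the target model: 1 <= ids.size() <= draft.size() + 1
        const auto ids = common_sampler_sample_n(smpl, ctx_tgt, draft);

        return (int) ids.size() - 1; // number of accepted draft tokens
    }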

common/speculative.cpp

Lines changed: 15 additions & 22 deletions

@@ -10,24 +10,19 @@
 #define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
 
 struct common_speculative {
-    struct common_speculative_params params;
-
-    llama_batch batch;
-
     struct llama_context * ctx;
     struct common_sampler * smpl;
 
+    llama_batch batch;
     llama_tokens prompt;
 };
 
 struct common_speculative * common_speculative_init(
-        struct common_speculative_params params,
         struct llama_context * ctx_dft) {
     auto * result = new common_speculative {
-        /* .params = */ params,
-        /* .batch  = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
         /* .ctx    = */ ctx_dft,
         /* .smpl   = */ nullptr,
+        /* .batch  = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
         /* .prompt = */ {},
     };
 
@@ -130,12 +125,11 @@ bool common_speculative_are_compatible(
     return true;
 }
 
-void common_speculative_add_draft(
+llama_tokens common_speculative_gen_draft(
         struct common_speculative * spec,
-        struct llama_batch & batch_tgt,
+        struct common_speculative_params params,
         const llama_tokens & prompt_tgt,
-        llama_token id_last,
-        llama_token n_past_tgt) {
+        llama_token id_last) {
     auto & batch = spec->batch;
     auto & ctx = spec->ctx;
     auto & smpl = spec->smpl;
@@ -144,7 +138,7 @@ void common_speculative_add_draft(
     int reuse_i = 0;
     int reuse_n = 0;
 
-    const int n_ctx = llama_n_ctx(ctx) - spec->params.n_draft;
+    const int n_ctx = llama_n_ctx(ctx) - params.n_draft;
 
     const int i_start = std::max<int>(0, (int) prompt_tgt.size() - n_ctx);
 
@@ -156,7 +150,7 @@
             cur++;
         }
 
-        if ((cur >= spec->params.n_reuse || prompt_tgt.size() <= n_ctx) && cur > reuse_n) {
+        if ((cur >= params.n_reuse || prompt_tgt.size() <= n_ctx) && cur > reuse_n) {
             reuse_i = i;
             reuse_n = cur;
         }
@@ -207,8 +201,11 @@ void common_speculative_add_draft(
 
     common_sampler_reset(smpl);
 
+    llama_tokens result;
+    result.reserve(params.n_draft);
+
     // sample n_draft tokens from the draft model
-    for (int i = 0; i < spec->params.n_draft; ++i) {
+    for (int i = 0; i < params.n_draft; ++i) {
         common_batch_clear(batch);
 
         common_sampler_sample(smpl, ctx, 0, true);
@@ -224,15 +221,15 @@ void common_speculative_add_draft(
         const llama_token id = cur_p->data[0].id;
 
         // only collect very high-confidence draft tokens
-        if (cur_p->data[0].p < spec->params.p_min) {
+        if (cur_p->data[0].p < params.p_min) {
            break;
        }
 
        common_sampler_accept(smpl, id, true);
 
-        common_batch_add(batch_tgt, id, n_past_tgt + i, { 0 }, true);
+        result.push_back(id);
 
-        if (batch_tgt.n_tokens > spec->params.n_draft) {
+        if (result.size() >= params.n_draft) {
            break;
        }
 
@@ -244,9 +241,5 @@ void common_speculative_add_draft(
         prompt.push_back(id);
     }
 
-    // don't waste time on small batches
-    // TODO: do not evaluate the draft model for that many rounds
-    if (batch_tgt.n_tokens < spec->params.n_min) {
-        batch_tgt.n_tokens = 1;
-    }
+    return result;
 }
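
Condensed, common_speculative_gen_draft now has a self-contained contract from the caller's point of view: it reuses as much of the draft context as possible (n_reuse), extends it with prompt_tgt and id_last, then greedily samples up to n_draft tokens and returns them instead of writing into the target batch, stopping early once the draft model's confidence drops below p_min. A paraphrase of the sampling loop above, with the batch/decode plumbing omitted (cur_p is obtained via common_sampler_get_candidates in the full source, which is not shown in these hunks):

    llama_tokens result;
    result.reserve(params.n_draft);

    for (int i = 0; i < params.n_draft; ++i) {
        // take the draft model's top candidate for the current position
        common_sampler_sample(smpl, ctx, 0, true);
        const auto * cur_p = common_sampler_get_candidates(smpl);
        const llama_token id = cur_p->data[0].id;

        // stop as soon as the draft model is not confident enough
        if (cur_p->data[0].p < params.p_min) {
            break;
        }

        common_sampler_accept(smpl, id, true);
        result.push_back(id);

        if (result.size() >= params.n_draft) {
            break;
        }

        // ... decode `id` with the draft model and continue from the new position ...
    }

    return result;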

common/speculative.h

Lines changed: 4 additions & 8 deletions

@@ -7,15 +7,12 @@ struct common_speculative;
 
 struct common_speculative_params {
     int n_draft = 16;
-    int n_min = 5; // do not add drafts smaller than this, TODO: leave this to user?
     int n_reuse = 256;
 
     float p_min = 0.9f;
 };
 
-struct common_speculative * common_speculative_init(
-        struct common_speculative_params params,
-        struct llama_context * ctx_dft);
+struct common_speculative * common_speculative_init(struct llama_context * ctx_dft);
 
 void common_speculative_free(struct common_speculative * spec);
 
@@ -25,9 +22,8 @@ bool common_speculative_are_compatible(
 
 // sample up to n_draft tokens and add them to the batch using the draft model
 //
-void common_speculative_add_draft(
+llama_tokens common_speculative_gen_draft(
         struct common_speculative * spec,
-        struct llama_batch & batch_tgt,
+        struct common_speculative_params params,
         const llama_tokens & prompt,
-        llama_token id_last,
-        llama_token n_past_tgt);
+        llama_token id_last);
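
With this change the speculator is constructed from the draft context alone, and the tuning parameters travel with each call instead of being fixed at init time. A usage sketch of the refactored API (variable names such as ctx_dft, prompt_tgt and id_last follow the example program below; error handling omitted):

    struct common_speculative * spec = common_speculative_init(ctx_dft);

    struct common_speculative_params params_spec;
    params_spec.n_draft = 16;
    params_spec.n_reuse = 256;
    params_spec.p_min   = 0.9f;

    // returns up to n_draft high-confidence tokens continuing [ prompt_tgt..., id_last ];
    // the caller decides what to do with them (e.g. append them to the target batch)
    llama_tokens draft = common_speculative_gen_draft(spec, params_spec, prompt_tgt, id_last);

    // ...

    common_speculative_free(spec);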

examples/speculative-simple/speculative-simple.cpp

Lines changed: 29 additions & 17 deletions

@@ -13,6 +13,9 @@
 int main(int argc, char ** argv) {
     common_params params;
 
+    // minimum size of the draft to use
+    const int n_min = 5;
+
     if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SPECULATIVE)) {
         return 1;
     }
@@ -92,31 +95,29 @@ int main(int argc, char ** argv) {
     // everything until here is standard initialization
     // the relevant stuff for speculative decoding starts here
 
-    const int n_input = inp.size();
-
     const auto t_enc_start = ggml_time_us();
 
     // target model sampling context
     struct common_sampler * smpl = common_sampler_init(model_tgt, params.sparams);
 
     // eval the prompt
-    llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), n_input - 1));
+    llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), inp.size() - 1));
 
     // note: keep the last token separate!
     llama_token id_last = inp.back();
 
-    auto prompt_dft = std::vector<llama_token>(inp.begin(), inp.end() - 1);
+    // all tokens currently in the target context
+    auto prompt_tgt = std::vector<llama_token>(inp.begin(), inp.end() - 1);
 
     int n_past = inp.size() - 1;
 
     // init the speculator
     struct common_speculative_params params_spec;
     params_spec.n_draft = n_draft;
-    params_spec.n_min = 5;
     params_spec.n_reuse = 256;
     params_spec.p_min = 0.9f;
 
-    struct common_speculative * spec = common_speculative_init(params_spec, ctx_dft);
+    struct common_speculative * spec = common_speculative_init(ctx_dft);
 
     llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1);
 
@@ -125,21 +126,30 @@ int main(int argc, char ** argv) {
     const auto t_dec_start = ggml_time_us();
 
     while (true) {
-        // always have a token to evaluate from before
-        common_batch_clear(batch_tgt);
-        common_batch_add (batch_tgt, id_last, n_past, { 0 }, true);
-
-        // optionally, append draft tokens to the target batch
+        // optionally, generate draft tokens that can be appended to the target batch
         //
         // this is the most important part of the speculation. the more probable tokens that are provided here
         // the better the performance will be. in theory, this computation can be performed asynchronously and even
         // offloaded to a remote device. it doesn't even have to be based on an LLM. instead, it can provide tokens
         // from a cache or lookup tables.
         //
-        common_speculative_add_draft(spec, batch_tgt, prompt_dft, id_last, n_past + 1);
+        llama_tokens draft = common_speculative_gen_draft(spec, params_spec, prompt_tgt, id_last);
+
+        // always have a token to evaluate from before - id_last
+        common_batch_clear(batch_tgt);
+        common_batch_add (batch_tgt, id_last, n_past++, { 0 }, true);
 
         // evaluate the target model on [id_last, draft0, draft1, ..., draftN-1]
         {
+            // do not waste time on small drafts
+            if (draft.size() < n_min) {
+                draft.clear();
+            }
+
+            for (size_t i = 0; i < draft.size(); ++i) {
+                common_batch_add(batch_tgt, draft[i], n_past + i, { 0 }, true);
+            }
+
             //LOG_DBG("target batch: %s\n", string_from(ctx_tgt, batch_tgt).c_str());
 
             llama_decode(ctx_tgt, batch_tgt);
@@ -152,11 +162,11 @@ int main(int argc, char ** argv) {
         // available logits from the batch and sample the next token until we run out of logits or the sampler
         // disagrees with the draft
         //
-        const auto ids = common_sampler_sample_n(smpl, ctx_tgt, batch_tgt);
+        const auto ids = common_sampler_sample_n(smpl, ctx_tgt, draft);
 
         GGML_ASSERT(ids.size() > 0); // there will always be at least one accepted token
 
-        n_past += ids.size();
+        n_past += ids.size() - 1;
         n_drafted += batch_tgt.n_tokens - 1;
         n_accept += ids.size() - 1;
 
@@ -192,16 +202,16 @@ int main(int argc, char ** argv) {
             break;
         }
 
-        LOG_DBG("accepted %d draft tokens, the last target token is: (%d, '%s')\n", (int) ids.size() - 1, id, token_str.c_str());
+        LOG_DBG("accepted %d/%d draft tokens, the last target token is: (%d, '%s')\n", (int) ids.size() - 1, (int) draft.size(), id, token_str.c_str());
 
         {
             LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past);
 
             llama_kv_cache_seq_rm(ctx_tgt, 0, n_past, -1);
         }
 
-        prompt_dft.push_back(id_last);
-        prompt_dft.insert(prompt_dft.end(), ids.begin(), ids.end() - 1);
+        prompt_tgt.push_back(id_last);
+        prompt_tgt.insert(prompt_tgt.end(), ids.begin(), ids.end() - 1);
 
         // remember the last accepted token for the next iteration
         id_last = id;
@@ -210,6 +220,8 @@ int main(int argc, char ** argv) {
 
     auto t_dec_end = ggml_time_us();
 
+    const int n_input = inp.size();
+
     LOG("\n\n");
 
     LOG_INF("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input, (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f));
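
As the comment inside the decoding loop notes, the draft does not have to come from a second LLM: it only needs to be a plausible continuation of [ prompt_tgt..., id_last ]. A hypothetical, self-contained sketch of such a non-LLM draft source (gen_draft_lookup is illustrative and not part of this commit); it simply replays the tokens that followed an earlier occurrence of id_last in the context:

    // hypothetical sketch: a lookup-based draft source in the spirit of the comment above
    static llama_tokens gen_draft_lookup(const llama_tokens & prompt_tgt, llama_token id_last, int n_draft) {
        llama_tokens draft;

        // search backwards for an earlier occurrence of id_last that has at least one successor
        for (int i = (int) prompt_tgt.size() - 2; i >= 0; --i) {
            if (prompt_tgt[i] != id_last) {
                continue;
            }

            // replay up to n_draft of the tokens that followed it as the speculative continuation
            for (int j = i + 1; j < (int) prompt_tgt.size() && (int) draft.size() < n_draft; ++j) {
                draft.push_back(prompt_tgt[j]);
            }

            break;
        }

        return draft; // may be empty - the main loop already tolerates an empty draft
    }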
