Skip to content

Commit eb7f0ed

Browse files
arthw authored and ggerganov committed
server : fix parallel speculative decoding (ggml-org#10513)
ggml-ci
1 parent 315fbd2 commit eb7f0ed

File tree

1 file changed

+31
-32
lines changed

1 file changed

+31
-32
lines changed

examples/server/server.cpp

Lines changed: 31 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -2267,50 +2267,49 @@ struct server_context {
22672267
continue; // continue loop of slots
22682268
}
22692269

2270-
llama_token id;
2270+
llama_token id = common_sampler_sample(slot.smpl, ctx, slot.i_batch - i);
22712271

2272-
{
2273-
completion_token_output result;
2274-
2275-
id = common_sampler_sample(slot.smpl, ctx, slot.i_batch - i);
2272+
slot.i_batch = -1;
22762273

2277-
slot.i_batch = -1;
2274+
common_sampler_accept(slot.smpl, id, true);
22782275

2279-
common_sampler_accept(slot.smpl, id, true);
2280-
2281-
slot.n_decoded += 1;
2282-
if (slot.n_decoded == 1) {
2283-
slot.t_start_generation = ggml_time_us();
2284-
slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
2285-
metrics.on_prompt_eval(slot);
2286-
}
2276+
slot.n_decoded += 1;
2277+
if (slot.n_decoded == 1) {
2278+
slot.t_start_generation = ggml_time_us();
2279+
slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
2280+
metrics.on_prompt_eval(slot);
2281+
}
22872282

2288-
result.tok = id;
2283+
completion_token_output result;
2284+
result.tok = id;
22892285

2290-
const auto * cur_p = common_sampler_get_candidates(slot.smpl);
2286+
const auto * cur_p = common_sampler_get_candidates(slot.smpl);
22912287

2292-
for (size_t i = 0; i < (size_t) slot.params.sampling.n_probs; ++i) {
2293-
result.probs.push_back({
2294-
cur_p->data[i].id,
2295-
i >= cur_p->size ? 0.0f : cur_p->data[i].p,
2296-
});
2297-
}
2288+
for (size_t i = 0; i < (size_t) slot.params.sampling.n_probs; ++i) {
2289+
result.probs.push_back({
2290+
cur_p->data[i].id,
2291+
i >= cur_p->size ? 0.0f : cur_p->data[i].p,
2292+
});
2293+
}
22982294

2299-
if (!process_token(result, slot)) {
2300-
// release slot because of stop condition
2301-
slot.release();
2302-
slot.print_timings();
2303-
send_final_response(slot);
2304-
metrics.on_prediction(slot);
2305-
continue;
2306-
}
2295+
if (!process_token(result, slot)) {
2296+
// release slot because of stop condition
2297+
slot.release();
2298+
slot.print_timings();
2299+
send_final_response(slot);
2300+
metrics.on_prediction(slot);
2301+
continue;
23072302
}
2303+
}
23082304

2309-
// check if the slot supports speculative decoding
2310-
if (!slot.can_speculate()) {
2305+
// do speculative decoding
2306+
for (auto & slot : slots) {
2307+
if (!slot.is_processing() || !slot.can_speculate()) {
23112308
continue;
23122309
}
23132310

2311+
llama_token id = slot.sampled;
2312+
23142313
struct common_speculative_params params_spec;
23152314
params_spec.n_draft = slot.params.speculative.n_max;
23162315
params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;

0 commit comments

Comments (0)