Commit 1da7b76

server : fix speculative decoding with context shift (#10641)
* server : fix speculative decoding with context shift

  ggml-ci

* server : take into account speculative limits

  ggml-ci

* server : add tests

1 parent 59f4db1 · commit 1da7b76

File tree

2 files changed: +58, -2 lines


examples/server/server.cpp

Lines changed: 27 additions & 2 deletions
@@ -921,6 +921,8 @@ struct server_context {
         slot.params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min);
 
         slot.params.speculative.n_min = std::min(slot.params.speculative.n_max, slot.params.speculative.n_min);
+        slot.params.speculative.n_min = std::max(slot.params.speculative.n_min, 2);
+        slot.params.speculative.n_max = std::max(slot.params.speculative.n_max, 0);
 
         if (slot.params.sampling.dry_base < 1.0f) {
             slot.params.sampling.dry_base = defaults.sampling.dry_base;
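
The two added clamps guarantee a sane (n_min, n_max) pair before a request is served. A minimal standalone sketch of how the effective limits resolve (not part of the commit; the local struct and the input values are illustrative):

    // sketch: how the clamps in the hunk above resolve user-supplied limits
    #include <algorithm>
    #include <cstdio>

    int main() {
        struct { int n_min, n_max; } in[] = { {0, 16}, {5, 3}, {4, 0} };
        for (const auto & c : in) {
            int n_min = std::min(c.n_max, c.n_min); // n_min may not exceed n_max
            n_min     = std::max(n_min, 2);         // the commit enforces a floor of 2
            int n_max = std::max(c.n_max, 0);       // negative n_max becomes 0
            printf("(n_min=%d, n_max=%d) -> (n_min=%d, n_max=%d)\n",
                   c.n_min, c.n_max, n_min, n_max);
        }
        // note: with n_max = 0 the later `n_draft_max < n_min` check always
        // triggers, so n_max = 0 effectively disables speculative decoding
        return 0;
    }
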
@@ -2322,17 +2324,38 @@ struct server_context {
                     continue;
                 }
 
+                // determine the max draft that fits the current slot state
+                int n_draft_max = slot.params.speculative.n_max;
+
+                // note: n_past is not yet increased for the `id` token sampled above
+                //       also, need to leave space for 1 extra token to allow context shifts
+                n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.n_past - 2);
+
+                if (slot.n_remaining > 0) {
+                    n_draft_max = std::min(n_draft_max, slot.n_remaining - 1);
+                }
+
+                SLT_DBG(slot, "max possible draft: %d\n", n_draft_max);
+
+                if (n_draft_max < slot.params.speculative.n_min) {
+                    SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.params.speculative.n_min);
+
+                    continue;
+                }
+
                 llama_token id = slot.sampled;
 
                 struct common_speculative_params params_spec;
-                params_spec.n_draft = slot.params.speculative.n_max;
+                params_spec.n_draft = n_draft_max;
                 params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
                 params_spec.p_min   = slot.params.speculative.p_min;
 
                 llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, slot.cache_tokens, id);
 
                 // ignore small drafts
                 if (slot.params.speculative.n_min > (int) draft.size()) {
+                    SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.params.speculative.n_min);
+
                    continue;
                 }
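
The draft budget is easy to verify with concrete numbers. A standalone sketch of the computation above (illustrative values, mirroring the 64-token-context scenario of the new tests below):

    // sketch: the n_draft_max budget from the hunk above, with example values
    #include <algorithm>
    #include <cstdio>

    int main() {
        const int n_ctx       = 64;  // slot.n_ctx
        const int n_past      = 60;  // tokens already in the KV cache
        const int n_remaining = -1;  // no n_predict limit (non-positive = unlimited)
        const int n_max       = 16;  // slot.params.speculative.n_max
        const int n_min       = 2;   // slot.params.speculative.n_min after clamping

        // reserve 1 position for the sampled `id` token and 1 for a context shift
        int n_draft_max = std::min(n_max, n_ctx - n_past - 2);
        if (n_remaining > 0) {
            n_draft_max = std::min(n_draft_max, n_remaining - 1);
        }

        printf("n_draft_max = %d -> %s\n", n_draft_max,
               n_draft_max < n_min ? "skip speculation" : "draft");
        // prints "n_draft_max = 2 -> draft"; at n_past = 61 the budget drops
        // to 1 < n_min and the slot falls back to normal decoding
        return 0;
    }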

@@ -2344,6 +2367,8 @@ struct server_context {
                     common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
                 }
 
+                SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);
+
                 llama_decode(ctx, slot.batch_spec);
 
                 // the accepted tokens from the speculation
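
The positions passed to common_batch_add show why the budget reserves 2 slots: per the note in the earlier hunk, the sampled token `id` occupies position n_past and the drafts follow it. A sketch of the resulting layout (not from the commit; the token ids are made up):

    // sketch: position layout of the speculative batch
    #include <cstdio>
    #include <vector>

    int main() {
        const int n_past = 60;                        // positions 0..59 already decoded
        const std::vector<int> draft = { 101, 102 };  // hypothetical draft token ids

        printf("pos %d: sampled token `id`\n", n_past);
        for (size_t i = 0; i < draft.size(); ++i) {
            printf("pos %d: draft[%zu] = %d\n", n_past + 1 + (int) i, i, draft[i]);
        }
        // highest position used is n_past + draft.size() <= n_ctx - 2, which
        // leaves one free position so a context shift can still be applied
        return 0;
    }
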
@@ -2372,7 +2397,7 @@ struct server_context {
                     }
                 }
 
-                SRV_DBG("accepted %d/%d draft tokens\n", (int) ids.size() - 1, (int) draft.size());
+                SLT_DBG(slot, "accepted %d/%d draft tokens, new n_past = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.n_past);
             }
         }

examples/server/tests/unit/test_speculative.py

Lines changed: 31 additions & 0 deletions
@@ -82,6 +82,37 @@ def test_different_draft_min_draft_max():
     last_content = res.body["content"]
 
 
+def test_slot_ctx_not_exceeded():
+    global server
+    server.n_ctx = 64
+    server.start()
+    res = server.make_request("POST", "/completion", data={
+        "prompt": "Hello " * 56,
+        "temperature": 0.0,
+        "top_k": 1,
+        "speculative.p_min": 0.0,
+    })
+    assert res.status_code == 200
+    assert len(res.body["content"]) > 0
+
+
+def test_with_ctx_shift():
+    global server
+    server.n_ctx = 64
+    server.start()
+    res = server.make_request("POST", "/completion", data={
+        "prompt": "Hello " * 56,
+        "temperature": 0.0,
+        "top_k": 1,
+        "n_predict": 64,
+        "speculative.p_min": 0.0,
+    })
+    assert res.status_code == 200
+    assert len(res.body["content"]) > 0
+    assert res.body["tokens_predicted"] == 64
+    assert res.body["truncated"] == True
+
+
 @pytest.mark.parametrize("n_slots,n_requests", [
     (1, 2),
     (2, 2),
