
Commit 0f878a6

speculative : manage context in common_speculative
ggml-ci
1 parent fe043ff commit 0f878a6

File tree: 9 files changed, +188 -144 lines changed


common/common.cpp

Lines changed: 66 additions & 6 deletions
@@ -536,12 +536,12 @@ std::string string_from(const struct llama_context * ctx, const struct llama_bat
                 [](const unsigned char c) { return !std::isprint(c); }),
             detokenized.end());
 
-        buf << "\n" << std::to_string(i)
-            << ":token '" << detokenized << "'"
-            << ":pos " << std::to_string(batch.pos[i])
-            << ":n_seq_id " << std::to_string(batch.n_seq_id[i])
-            << ":seq_id " << std::to_string(batch.seq_id[i][0])
-            << ":logits " << std::to_string(batch.logits[i]);
+        buf << "\n" << std::to_string(i)
+            << ", token '" << detokenized << "'"
+            << ", pos " << std::to_string(batch.pos[i])
+            << ", n_seq_id " << std::to_string(batch.n_seq_id[i])
+            << ", seq_id " << std::to_string(batch.seq_id[i][0])
+            << ", logits " << std::to_string(batch.logits[i]);
     }
 
     buf << " ]";
@@ -1490,6 +1490,66 @@ void common_batch_add(
     batch.n_tokens++;
 }
 
+//
+// Token utils
+//
+
+size_t common_lcp(const llama_tokens & a, const llama_tokens & b) {
+    size_t i;
+    for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}
+
+    return i;
+}
+
+size_t common_lcs(const llama_tokens & a, const llama_tokens & b) {
+    // check for empty sequences
+    if (a.empty() || b.empty()) {
+        return 0;
+    }
+
+    // get the lengths of the input sequences
+    size_t a_len = a.size();
+    size_t b_len = b.size();
+
+    // initialize the maximum length of the longest common subsequence (LCS)
+    size_t max_length = 0;
+
+    // use two rows instead of a 2D matrix to optimize space
+    std::vector<size_t> prev_row(b_len + 1, 0);
+    std::vector<size_t> curr_row(b_len + 1, 0);
+
+    // iterate through the elements of a
+    for (size_t i = 1; i <= a_len; i++) {
+        // iterate through the elements of b
+        for (size_t j = 1; j <= b_len; j++) {
+            // if elements at the current positions match
+            if (a[i - 1] == b[j - 1]) {
+                // if it's the first element of either sequences, set LCS length to 1
+                if (i == 1 || j == 1) {
+                    curr_row[j] = 1;
+                } else {
+                    // increment LCS length by 1 compared to the previous element
+                    curr_row[j] = prev_row[j - 1] + 1;
+                }
+
+                // update max_length if necessary
+                if (curr_row[j] > max_length) {
+                    max_length = curr_row[j];
+                }
+            } else {
+                // reset LCS length if elements don't match
+                curr_row[j] = 0;
+            }
+        }
+
+        // update the previous row for the next iteration
+        prev_row = curr_row;
+    }
+
+    // return the maximum length of the LCS
+    return max_length;
+}
+
 //
 // Vocab utils
 //
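Note (not part of the commit): a small usage sketch of the two new helpers, assuming only the declarations added to common/common.h in this change. Despite the "subsequence" wording in the comments, common_lcs resets its running count on every mismatch, so in effect it returns the length of the longest contiguous run of matching tokens.

#include <cstdio>

#include "common.h"

static void token_utils_example() {
    const llama_tokens a = { 1, 2, 3, 9, 9, 4, 5 };
    const llama_tokens b = { 1, 2, 3, 7, 4, 5, 6 };

    const size_t lcp = common_lcp(a, b); // shared prefix {1, 2, 3}        -> 3
    const size_t lcs = common_lcs(a, b); // longest matching run {1, 2, 3} -> 3

    printf("lcp = %zu, lcs = %zu\n", lcp, lcs);
}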

common/common.h

Lines changed: 14 additions & 0 deletions
@@ -33,6 +33,8 @@ struct common_lora_adapter_container : common_lora_adapter_info {
     struct llama_lora_adapter * adapter;
 };
 
+using llama_tokens = std::vector<llama_token>;
+
 // build info
 extern int LLAMA_BUILD_NUMBER;
 extern char const * LLAMA_COMMIT;
@@ -461,7 +463,9 @@ struct llama_model * common_load_model_from_hf(const char * repo, const char * f
 // clear LoRA adapters from context, then apply new list of adapters
 void common_lora_adapters_apply(struct llama_context * ctx, std::vector<common_lora_adapter_container> & lora_adapters);
 
+//
 // Batch utils
+//
 
 void common_batch_clear(struct llama_batch & batch);
 
@@ -472,6 +476,16 @@ void common_batch_add(
                 const std::vector<llama_seq_id> & seq_ids,
                 bool logits);
 
+//
+// Token utils
+//
+
+// longest common prefix
+size_t common_lcp(const llama_tokens & a, const llama_tokens & b);
+
+// longest common subsequence
+size_t common_lcs(const llama_tokens & a, const llama_tokens & b);
+
 //
 // Vocab utils
 //

common/sampling.cpp

Lines changed: 22 additions & 0 deletions
@@ -342,6 +342,28 @@ std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl,
     return result;
 }
 
+std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const struct llama_batch & batch, bool grammar_first) {
+    std::vector<int> idxs;
+    idxs.reserve(batch.n_tokens);
+
+    std::vector<llama_token> draft;
+    draft.reserve(batch.n_tokens);
+
+    for (int i = 0; i < batch.n_tokens; i++) {
+        if (batch.logits[i] == 0) {
+            continue;
+        }
+
+        if (idxs.size() > 0) {
+            GGML_ASSERT(batch.pos[idxs.back()] + 1 == batch.pos[i]);
+            draft.push_back(batch.token[i]);
+        }
+        idxs.push_back(i);
+    }
+
+    return common_sampler_sample_n(gsmpl, ctx, idxs, draft, grammar_first);
+}
+
 uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
     return llama_sampler_get_seed(gsmpl->chain);
 }
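Note (not part of the commit): the new overload collects the index of every batch entry that has logits enabled, treats the tokens after the first such entry as the draft to be verified, and forwards everything to the existing (idxs, draft) variant. A rough usage sketch; the helper below and its arguments are illustrative assumptions, not something this change adds:

#include "common.h"
#include "sampling.h"

static std::vector<llama_token> verify_draft_sketch(
        struct llama_context  * ctx_tgt,
        struct common_sampler * smpl,
        const std::vector<llama_token> & draft,
        llama_token id_last,
        llama_pos   n_past) {
    llama_batch batch_tgt = llama_batch_init(1 + (int32_t) draft.size(), 0, 1);

    common_batch_clear(batch_tgt);

    // last accepted token followed by the drafted continuation, all requesting logits
    common_batch_add(batch_tgt, id_last, n_past, { 0 }, true);
    for (int32_t i = 0; i < (int32_t) draft.size(); ++i) {
        common_batch_add(batch_tgt, draft[i], n_past + 1 + i, { 0 }, true);
    }

    llama_decode(ctx_tgt, batch_tgt);

    // the logits-enabled entries have consecutive positions, which is what the
    // GGML_ASSERT in the overload checks
    std::vector<llama_token> ids = common_sampler_sample_n(smpl, ctx_tgt, batch_tgt);

    llama_batch_free(batch_tgt);

    return ids;
}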

common/sampling.h

Lines changed: 2 additions & 0 deletions
@@ -73,6 +73,8 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
 //
 std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const std::vector<llama_token> & draft, bool grammar_first = false);
 
+std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const struct llama_batch & batch, bool grammar_first = false);
+
 uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);
 
 // helpers

common/speculative.cpp

Lines changed: 67 additions & 37 deletions
@@ -11,22 +11,18 @@ struct common_speculative {
 
     struct common_sampler * smpl;
 
-    std::vector<int> i_batch_tgt;
-
-    std::vector<llama_token> tokens;
+    llama_tokens prompt_last;
 };
 
 struct common_speculative * common_speculative_init(struct common_speculative_params params) {
     auto * result = new common_speculative {
         /* .params = */ params,
         /* .batch_dft = */ llama_batch_init(llama_n_batch(params.ctx_dft), 0, 1),
         /* .smpl = */ nullptr,
-        /* .i_batch_tgt = */ {},
-        /* .tokens = */ {},
     };
 
     // TODO: optimize or pass from outside?
-#if 0
+#if 1
     {
         common_sampler_params sparams;
         sparams.no_perf = false;
@@ -70,30 +66,79 @@ void common_speculative_free(struct common_speculative * spec) {
     delete spec;
 }
 
-void common_speculative_set_prompt(struct common_speculative * spec, llama_token * tokens, int32_t n_tokens) {
-    llama_kv_cache_clear(spec->params.ctx_dft);
-
-    // TODO: error handling
-    llama_decode(spec->params.ctx_dft, llama_batch_get_one(tokens, n_tokens));
-}
-
 void common_speculative_add_draft(
         struct common_speculative * spec,
         struct llama_batch & batch_tgt,
+        const llama_tokens & prompt,
         llama_token id_last,
-        int n_past) {
-    spec->tokens.clear();
+        llama_token n_past_tgt) {
 
-    spec->i_batch_tgt.clear();
-    spec->i_batch_tgt.push_back(0);
+    int reuse_i = 0;
+    int reuse_n = 0;
 
-    common_sampler_reset(spec->smpl);
+    const int n_ctx = llama_n_ctx(spec->params.ctx_dft) - spec->params.n_draft;
+
+    const int i_start = std::max<int>(0, (int) prompt.size() - n_ctx);
+
+    for (int i = 0; i < (int) spec->prompt_last.size(); ++i) {
+        int cur = 0;
+        while (i_start + cur < (int) prompt.size() &&
+               i       + cur < (int) spec->prompt_last.size() &&
+               prompt[i_start + cur] == spec->prompt_last[i + cur]) {
+            cur++;
+        }
+
+        if ((cur >= spec->params.n_reuse || prompt.size() <= n_ctx) && cur > reuse_n) {
+            reuse_i = i;
+            reuse_n = cur;
+        }
+    }
+
+    LOG_DBG("%s: reuse_i = %d, reuse_n = %d\n", __func__, reuse_i, reuse_n);
+
+    if (reuse_n == 0) {
+        llama_kv_cache_clear(spec->params.ctx_dft);
+
+        spec->prompt_last.clear();
+    } else {
+        llama_kv_cache_seq_rm (spec->params.ctx_dft, 0, 0, reuse_i);
+        llama_kv_cache_seq_rm (spec->params.ctx_dft, 0, reuse_i + reuse_n, -1);
+        llama_kv_cache_seq_add(spec->params.ctx_dft, 0, reuse_i, -1, -reuse_i);
+
+        spec->prompt_last.erase(spec->prompt_last.begin(), spec->prompt_last.begin() + reuse_i);
+        spec->prompt_last.erase(spec->prompt_last.begin() + reuse_n, spec->prompt_last.end());
+    }
+
+    common_batch_clear(spec->batch_dft);
+
+    for (int i = i_start + reuse_n; i < (int) prompt.size(); ++i) {
+        //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt[i]);
+        common_batch_add(spec->batch_dft, prompt[i], i - i_start, { 0 }, false);
+
+        spec->prompt_last.push_back(prompt[i]);
+    }
+
+    const llama_pos n_past = prompt.size() - i_start;
+
+    LOG_DBG("%s: n_past = %d\n", __func__, n_past);
+
+    if (spec->batch_dft.n_tokens > 0) {
+        LOG_DBG("%s: draft batch: %s\n", __func__, string_from(spec->params.ctx_dft, spec->batch_dft).c_str());
+
+        llama_decode(spec->params.ctx_dft, spec->batch_dft);
+    }
 
     common_batch_clear(spec->batch_dft);
     common_batch_add (spec->batch_dft, id_last, n_past, { 0 }, true);
 
+    spec->prompt_last.push_back(id_last);
+
+    LOG_DBG("%s: prompt_last: %s\n", __func__, string_from(spec->params.ctx_dft, spec->prompt_last).c_str());
+
     llama_decode(spec->params.ctx_dft, spec->batch_dft);
 
+    common_sampler_reset(spec->smpl);
+
     // sample n_draft tokens from the draft model
     for (int i = 0; i < spec->params.n_draft; ++i) {
         common_batch_clear(spec->batch_dft);
@@ -111,18 +156,13 @@ void common_speculative_add_draft(
         const llama_token id = cur_p->data[0].id;
 
         // only collect very high-confidence draft tokens
-        if (cur_p->data[0].p < 0.75 && spec->tokens.size() >= 0) {
+        if (cur_p->data[0].p < spec->params.p_min) {
            break;
        }
 
        common_sampler_accept(spec->smpl, id, true);
 
-        spec->tokens.push_back(id);
-
-        // add unique drafted tokens to the target batch
-        spec->i_batch_tgt.push_back(batch_tgt.n_tokens);
-
-        common_batch_add(batch_tgt, id, n_past + i + 1, { 0 }, true);
+        common_batch_add(batch_tgt, id, n_past_tgt + i, { 0 }, true);
 
        if (batch_tgt.n_tokens > spec->params.n_draft) {
            break;
@@ -132,23 +172,13 @@ void common_speculative_add_draft(
 
         // evaluate the drafted tokens on the draft model
         llama_decode(spec->params.ctx_dft, spec->batch_dft);
+
+        spec->prompt_last.push_back(id);
     }
 
     // don't waste time on small batches
     // TODO: do not evaluate the draft model for that many rounds
     if (batch_tgt.n_tokens < spec->params.n_min) {
         batch_tgt.n_tokens = 1;
-        spec->tokens.resize(0);
-        spec->i_batch_tgt.resize(1);
     }
-
-    // print current draft sequences
-    LOG_DBG("draft %s\n", string_from(spec->params.ctx_dft, spec->tokens).c_str());
-}
-
-std::vector<llama_token> common_speculative_sample(
-        struct common_speculative * spec,
-        struct common_sampler * smpl,
-        struct llama_context * ctx_tgt) {
-    return common_sampler_sample_n(smpl, ctx_tgt, spec->i_batch_tgt, spec->tokens);
 }
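Note (not part of the commit): with common_speculative_set_prompt and common_speculative_sample removed, one target-side step of the loop reduces to roughly the sketch below. The helper and its arguments (prompt_tgt, id_last, n_past, ctx_tgt, smpl) are illustrative assumptions about what the caller tracks, and passing n_past + 1 assumes the first drafted position comes directly after id_last.

#include "common.h"
#include "sampling.h"
#include "speculative.h"

static llama_tokens speculative_step_sketch(
        struct common_speculative * spec,
        struct common_sampler     * smpl,
        struct llama_context      * ctx_tgt,
        struct llama_batch        & batch_tgt,
        const llama_tokens        & prompt_tgt,  // all target-side tokens so far
        llama_token                 id_last,     // last sampled token
        llama_pos                   n_past) {    // its position in the target context
    common_batch_clear(batch_tgt);
    common_batch_add  (batch_tgt, id_last, n_past, { 0 }, true);

    // the draft context is now managed internally: the helper reuses whatever
    // prefix of prompt_tgt is already in its KV cache, decodes only the new tail,
    // and appends up to n_draft high-confidence tokens to batch_tgt
    common_speculative_add_draft(spec, batch_tgt, prompt_tgt, id_last, n_past + 1);

    // evaluate on the target model and sample at every drafted position
    llama_decode(ctx_tgt, batch_tgt);

    return common_sampler_sample_n(smpl, ctx_tgt, batch_tgt);
}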

common/speculative.h

Lines changed: 6 additions & 21 deletions
@@ -1,14 +1,16 @@
 #pragma once
 
 #include "llama.h"
-
-#include <vector>
+#include "common.h"
 
 struct common_speculative;
 
 struct common_speculative_params {
     int n_draft = 16;
     int n_min = 5; // do not add drafts smaller than this, TODO: leave this to user?
+    int n_reuse = 256;
+
+    float p_min = 0.9f;
 
     struct llama_model * model_dft = nullptr;
 
@@ -19,28 +21,11 @@ struct common_speculative * common_speculative_init(struct common_speculative_pa
 
 void common_speculative_free(struct common_speculative * spec);
 
-// TODO: remove
-void common_speculative_set_prompt(struct common_speculative * spec, llama_token * tokens, int32_t n_tokens);
-
 // sample up to n_draft tokens and add them to the batch using the draft model
 //
-// TODO: change to:
-//
-// void common_speculative_add_draft(
-//         struct common_speculative * spec,
-//         struct llama_batch & batch_tgt,
-//         llama_token * tokens,
-//         int32_t n_tokens);
-//
-// and update the internal logic to compute only the new tokens
-//
 void common_speculative_add_draft(
         struct common_speculative * spec,
         struct llama_batch & batch_tgt,
+        const llama_tokens & prompt,
         llama_token id_last,
-        int n_past);
-
-std::vector<llama_token> common_speculative_sample(
-        struct common_speculative * spec,
-        struct common_sampler * smpl,
-        struct llama_context * ctx_tgt);
+        llama_token n_past_tgt);
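Note (not part of the commit): an initialization sketch for the new fields, using the defaults shown above. model_dft and ctx_dft are assumed to have been created by the caller; the ctx_dft field is not visible in this hunk, but params.ctx_dft is used throughout common/speculative.cpp above.

#include "speculative.h"

static struct common_speculative * init_draft_helper(struct llama_model * model_dft, struct llama_context * ctx_dft) {
    common_speculative_params params_spec;

    params_spec.n_draft = 16;   // maximum number of tokens to draft per step
    params_spec.n_min   = 5;    // drop drafts shorter than this
    params_spec.n_reuse = 256;  // minimum matching run needed to shift the draft KV cache instead of clearing it
    params_spec.p_min   = 0.9f; // stop drafting once the draft model's top-token probability falls below this

    params_spec.model_dft = model_dft; // assumed: previously loaded draft model
    params_spec.ctx_dft   = ctx_dft;   // assumed: its context

    return common_speculative_init(params_spec);
}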

examples/server/server.cpp

Lines changed: 2 additions & 2 deletions
@@ -743,7 +743,7 @@ struct server_context {
         }
 
         // length of the Longest Common Subsequence between the current slot's prompt and the input prompt
-        int cur_lcs_len = longest_common_subsequence(slot.cache_tokens, task.prompt_tokens);
+        int cur_lcs_len = common_lcs(slot.cache_tokens, task.prompt_tokens);
 
         // fraction of the common subsequence length compared to the current slot's prompt length
         float cur_similarity = static_cast<float>(cur_lcs_len) / static_cast<int>(slot.cache_tokens.size());
@@ -1960,7 +1960,7 @@ struct server_context {
 
             if (slot.params.cache_prompt) {
                 // reuse any previously computed tokens that are common with the new prompt
-                slot.n_past = longest_common_prefix(slot.cache_tokens, prompt_tokens);
+                slot.n_past = common_lcp(slot.cache_tokens, prompt_tokens);
 
                 // reuse chunks from the cached prompt by shifting their KV cache in the new position
                 if (params.n_cache_reuse > 0) {
