 #include "common.h"
 #include "sampling.h"
 
+#include <cstring>
+
+#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128
+#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
+
 struct common_speculative {
     struct common_speculative_params params;
 
-    llama_batch batch_dft;
+    llama_batch batch;
 
+    struct llama_context * ctx;
     struct common_sampler * smpl;
 
-    llama_tokens prompt_last;
+    llama_tokens prompt;
 };
 
-struct common_speculative * common_speculative_init(struct common_speculative_params params) {
+struct common_speculative * common_speculative_init(
+        struct common_speculative_params params,
+        struct llama_context * ctx_dft) {
     auto * result = new common_speculative {
-        /* .params    = */ params,
-        /* .batch_dft = */ llama_batch_init(llama_n_batch(params.ctx_dft), 0, 1),
-        /* .smpl      = */ nullptr,
+        /* .params = */ params,
+        /* .batch  = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
+        /* .ctx    = */ ctx_dft,
+        /* .smpl   = */ nullptr,
+        /* .prompt = */ {},
     };
 
     // TODO: optimize or pass from outside?
@@ -36,7 +46,7 @@ struct common_speculative * common_speculative_init(struct common_speculative_pa
             COMMON_SAMPLER_TYPE_INFILL,
         };
 
-        result->smpl = common_sampler_init(params.model_dft, sparams);
+        result->smpl = common_sampler_init(llama_get_model(ctx_dft), sparams);
     }
 #else
     {
@@ -49,46 +59,104 @@ struct common_speculative * common_speculative_init(struct common_speculative_pa
             COMMON_SAMPLER_TYPE_TOP_K,
         };
 
-        result->smpl = common_sampler_init(params.model_dft, sparams);
+        result->smpl = common_sampler_init(llama_get_model(ctx_dft), sparams);
     }
 #endif
 
-    result->batch_dft = llama_batch_init(llama_n_batch(params.ctx_dft), 0, 1);
-
     return result;
 }
 
 void common_speculative_free(struct common_speculative * spec) {
     common_sampler_free(spec->smpl);
 
-    llama_batch_free(spec->batch_dft);
+    llama_batch_free(spec->batch);
 
     delete spec;
 }
 
+bool common_speculative_are_compatible(
+        const struct llama_context * ctx_tgt,
+        const struct llama_context * ctx_dft) {
+    const struct llama_model * model_tgt = llama_get_model(ctx_tgt);
+    const struct llama_model * model_dft = llama_get_model(ctx_dft);
+
+    const enum llama_vocab_type vocab_type_tgt = llama_vocab_type(model_tgt);
+    LOG_DBG("%s: vocab_type tgt: %d\n", __func__, vocab_type_tgt);
+
+    const enum llama_vocab_type vocab_type_dft = llama_vocab_type(model_dft);
+    LOG_DBG("%s: vocab_type dft: %d\n", __func__, vocab_type_dft);
+
+    if (vocab_type_tgt != vocab_type_dft) {
+        LOG_ERR("%s: draft model vocab type must match target model to use speculation but "
+                "vocab_type_dft = %d while vocab_type_tgt = %d\n", __func__, vocab_type_dft, vocab_type_tgt);
+        return false;
+    }
+
+    if (llama_add_bos_token(model_tgt) != llama_add_bos_token(model_dft) ||
+        llama_add_eos_token(model_tgt) != llama_add_eos_token(model_dft) ||
+        llama_token_bos(model_tgt) != llama_token_bos(model_dft) ||
+        llama_token_eos(model_tgt) != llama_token_eos(model_dft)
+    ) {
+        LOG_ERR("%s: draft model special tokens must match target model to use speculation\n", __func__);
+        return false;
+    }
+
+    {
+        const int n_vocab_tgt = llama_n_vocab(model_tgt);
+        const int n_vocab_dft = llama_n_vocab(model_dft);
+
+        const int vocab_diff = std::abs(n_vocab_tgt - n_vocab_dft);
+
+        if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
+            LOG_ERR("%s: draft model vocab must closely match target model to use speculation but "
+                    "target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
+                    __func__, n_vocab_tgt, llama_n_vocab(model_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
+            return false;
+        }
+
+        for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
+            const char * token_text_tgt = llama_token_get_text(model_tgt, i);
+            const char * token_text_dft = llama_token_get_text(model_dft, i);
+            if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
+                LOG_ERR("%s: draft model vocab must match target model to use speculation but "
+                        "token %d content differs - target '%s', draft '%s'\n", __func__, i,
+                        common_token_to_piece(ctx_tgt, i).c_str(),
+                        common_token_to_piece(ctx_dft, i).c_str());
+                return false;
+            }
+        }
+    }
+
+    return true;
+}
+
 void common_speculative_add_draft(
         struct common_speculative * spec,
         struct llama_batch & batch_tgt,
-        const llama_tokens & prompt,
+        const llama_tokens & prompt_tgt,
         llama_token id_last,
         llama_token n_past_tgt) {
+    auto & batch  = spec->batch;
+    auto & ctx    = spec->ctx;
+    auto & smpl   = spec->smpl;
+    auto & prompt = spec->prompt;
 
     int reuse_i = 0;
     int reuse_n = 0;
 
-    const int n_ctx = llama_n_ctx(spec->params.ctx_dft) - spec->params.n_draft;
+    const int n_ctx = llama_n_ctx(ctx) - spec->params.n_draft;
 
-    const int i_start = std::max<int>(0, (int) prompt.size() - n_ctx);
+    const int i_start = std::max<int>(0, (int) prompt_tgt.size() - n_ctx);
 
-    for (int i = 0; i < (int) spec->prompt_last.size(); ++i) {
+    for (int i = 0; i < (int) prompt.size(); ++i) {
         int cur = 0;
-        while (i_start + cur < (int) prompt.size() &&
-               i + cur < (int) spec->prompt_last.size() &&
-               prompt[i_start + cur] == spec->prompt_last[i + cur]) {
+        while (i_start + cur < (int) prompt_tgt.size() &&
+               i + cur < (int) prompt.size() &&
+               prompt_tgt[i_start + cur] == prompt[i + cur]) {
             cur++;
         }
 
-        if ((cur >= spec->params.n_reuse || prompt.size() <= n_ctx) && cur > reuse_n) {
+        if ((cur >= spec->params.n_reuse || prompt_tgt.size() <= n_ctx) && cur > reuse_n) {
             reuse_i = i;
             reuse_n = cur;
         }
@@ -97,59 +165,59 @@ void common_speculative_add_draft(
     LOG_DBG("%s: reuse_i = %d, reuse_n = %d\n", __func__, reuse_i, reuse_n);
 
     if (reuse_n == 0) {
-        llama_kv_cache_clear(spec->params.ctx_dft);
+        llama_kv_cache_clear(ctx);
 
-        spec->prompt_last.clear();
+        prompt.clear();
     } else {
-        llama_kv_cache_seq_rm (spec->params.ctx_dft, 0, 0, reuse_i);
-        llama_kv_cache_seq_rm (spec->params.ctx_dft, 0, reuse_i + reuse_n, -1);
-        llama_kv_cache_seq_add(spec->params.ctx_dft, 0, reuse_i, -1, -reuse_i);
+        llama_kv_cache_seq_rm (ctx, 0, 0, reuse_i);
+        llama_kv_cache_seq_rm (ctx, 0, reuse_i + reuse_n, -1);
+        llama_kv_cache_seq_add(ctx, 0, reuse_i, -1, -reuse_i);
 
-        spec->prompt_last.erase(spec->prompt_last.begin(), spec->prompt_last.begin() + reuse_i);
-        spec->prompt_last.erase(spec->prompt_last.begin() + reuse_n, spec->prompt_last.end());
+        prompt.erase(prompt.begin(), prompt.begin() + reuse_i);
+        prompt.erase(prompt.begin() + reuse_n, prompt.end());
     }
 
-    common_batch_clear(spec->batch_dft);
+    common_batch_clear(batch);
 
-    for (int i = i_start + reuse_n; i < (int) prompt.size(); ++i) {
-        //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt[i]);
-        common_batch_add(spec->batch_dft, prompt[i], i - i_start, { 0 }, false);
+    for (int i = i_start + reuse_n; i < (int) prompt_tgt.size(); ++i) {
+        //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]);
+        common_batch_add(batch, prompt_tgt[i], i - i_start, { 0 }, false);
 
-        spec->prompt_last.push_back(prompt[i]);
+        prompt.push_back(prompt_tgt[i]);
     }
 
-    const llama_pos n_past = prompt.size() - i_start;
+    const llama_pos n_past = prompt_tgt.size() - i_start;
 
     LOG_DBG("%s: n_past = %d\n", __func__, n_past);
 
-    if (spec->batch_dft.n_tokens > 0) {
-        LOG_DBG("%s: draft batch: %s\n", __func__, string_from(spec->params.ctx_dft, spec->batch_dft).c_str());
+    if (batch.n_tokens > 0) {
+        LOG_DBG("%s: draft batch: %s\n", __func__, string_from(ctx, batch).c_str());
 
-        llama_decode(spec->params.ctx_dft, spec->batch_dft);
+        llama_decode(ctx, batch);
     }
 
-    common_batch_clear(spec->batch_dft);
-    common_batch_add (spec->batch_dft, id_last, n_past, { 0 }, true);
+    common_batch_clear(batch);
+    common_batch_add (batch, id_last, n_past, { 0 }, true);
 
-    spec->prompt_last.push_back(id_last);
+    prompt.push_back(id_last);
 
-    LOG_DBG("%s: prompt_last: %s\n", __func__, string_from(spec->params.ctx_dft, spec->prompt_last).c_str());
+    LOG_DBG("%s: prompt_last: %s\n", __func__, string_from(ctx, prompt).c_str());
 
-    llama_decode(spec->params.ctx_dft, spec->batch_dft);
+    llama_decode(ctx, batch);
 
-    common_sampler_reset(spec->smpl);
+    common_sampler_reset(smpl);
 
     // sample n_draft tokens from the draft model
     for (int i = 0; i < spec->params.n_draft; ++i) {
-        common_batch_clear(spec->batch_dft);
+        common_batch_clear(batch);
 
-        common_sampler_sample(spec->smpl, spec->params.ctx_dft, 0, true);
+        common_sampler_sample(smpl, ctx, 0, true);
 
-        const auto * cur_p = common_sampler_get_candidates(spec->smpl);
+        const auto * cur_p = common_sampler_get_candidates(smpl);
 
         for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
             LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
-                k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(spec->params.ctx_dft, cur_p->data[k].id).c_str());
+                k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str());
         }
 
         // add drafted token for each sequence
@@ -160,20 +228,20 @@ void common_speculative_add_draft(
             break;
         }
 
-        common_sampler_accept(spec->smpl, id, true);
+        common_sampler_accept(smpl, id, true);
 
         common_batch_add(batch_tgt, id, n_past_tgt + i, { 0 }, true);
 
         if (batch_tgt.n_tokens > spec->params.n_draft) {
             break;
         }
 
-        common_batch_add(spec->batch_dft, id, n_past + i + 1, { 0 }, true);
+        common_batch_add(batch, id, n_past + i + 1, { 0 }, true);
 
         // evaluate the drafted tokens on the draft model
-        llama_decode(spec->params.ctx_dft, spec->batch_dft);
+        llama_decode(ctx, batch);
 
-        spec->prompt_last.push_back(id);
+        prompt.push_back(id);
     }
 
     // don't waste time on small batches
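Note (illustrative, not part of the diff): the sketch below shows how the refactored API introduced above might be wired together by a caller. Compatibility between the target and draft contexts is checked with common_speculative_are_compatible(), the speculative state is created through the new common_speculative_init(params, ctx_dft) signature, and common_speculative_add_draft() appends drafted tokens to the target batch for verification. The helper name, the "speculative.h" header name, the concrete n_draft/n_reuse values, and the position bookkeeping around id_last are assumptions for this example; a real caller would also create the state once and reuse it across decoding rounds rather than per call.

#include "common.h"
#include "sampling.h"
#include "speculative.h" // assumed name of the header declaring the functions in this diff

// one speculative round (sketch): draft tokens with ctx_dft and queue them on
// batch_tgt for verification by the target model; prompt_tgt holds the tokens
// already evaluated on the target, id_last is the last sampled token
// (assumed to sit at position n_past_tgt - 1)
static void speculative_round_sketch(
        struct llama_context * ctx_tgt,
        struct llama_context * ctx_dft,
        const llama_tokens & prompt_tgt,
        llama_token id_last,
        llama_pos n_past_tgt) {
    if (!common_speculative_are_compatible(ctx_tgt, ctx_dft)) {
        return; // vocabularies diverge too much - skip speculation entirely
    }

    struct common_speculative_params params; // n_draft and n_reuse are the fields used by the code above
    params.n_draft = 16;  // assumed value: max tokens to draft per round
    params.n_reuse = 256; // assumed value: min prefix match before reusing the draft KV cache

    struct common_speculative * spec = common_speculative_init(params, ctx_dft);

    llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1);

    // seed the target batch with the last sampled token; common_speculative_add_draft()
    // then appends the drafted tokens at positions n_past_tgt, n_past_tgt + 1, ...
    common_batch_clear(batch_tgt);
    common_batch_add  (batch_tgt, id_last, n_past_tgt - 1, { 0 }, true);

    common_speculative_add_draft(spec, batch_tgt, prompt_tgt, id_last, n_past_tgt);

    // verify all drafted tokens with a single decode on the target context
    llama_decode(ctx_tgt, batch_tgt);

    llama_batch_free(batch_tgt);
    common_speculative_free(spec);
}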
0 commit comments