@@ -921,6 +921,8 @@ struct server_context {
921
921
slot.params .speculative .p_min = json_value (data, " speculative.p_min" , defaults.speculative .p_min );
922
922
923
923
slot.params .speculative .n_min = std::min (slot.params .speculative .n_max , slot.params .speculative .n_min );
924
+ slot.params .speculative .n_min = std::max (slot.params .speculative .n_min , 2 );
925
+ slot.params .speculative .n_max = std::max (slot.params .speculative .n_max , 0 );
924
926
925
927
if (slot.params .sampling .dry_base < 1 .0f ) {
926
928
slot.params .sampling .dry_base = defaults.sampling .dry_base ;
@@ -2322,17 +2324,38 @@ struct server_context {
2322
2324
continue ;
2323
2325
}
2324
2326
2327
+ // determine the max draft that fits the current slot state
2328
+ int n_draft_max = slot.params .speculative .n_max ;
2329
+
2330
+ // note: n_past is not yet increased for the `id` token sampled above
2331
+ // also, need to leave space for 1 extra token to allow context shifts
2332
+ n_draft_max = std::min (n_draft_max, slot.n_ctx - slot.n_past - 2 );
2333
+
2334
+ if (slot.n_remaining > 0 ) {
2335
+ n_draft_max = std::min (n_draft_max, slot.n_remaining - 1 );
2336
+ }
2337
+
2338
+ SLT_DBG (slot, " max possible draft: %d\n " , n_draft_max);
2339
+
2340
+ if (n_draft_max < slot.params .speculative .n_min ) {
2341
+ SLT_DBG (slot, " the max possible draft is too small: %d < %d - skipping speculative decoding\n " , n_draft_max, slot.params .speculative .n_min );
2342
+
2343
+ continue ;
2344
+ }
2345
+
2325
2346
llama_token id = slot.sampled ;
2326
2347
2327
2348
struct common_speculative_params params_spec;
2328
- params_spec.n_draft = slot. params . speculative . n_max ;
2349
+ params_spec.n_draft = n_draft_max ;
2329
2350
params_spec.n_reuse = llama_n_ctx (slot.ctx_dft ) - slot.params .speculative .n_max ;
2330
2351
params_spec.p_min = slot.params .speculative .p_min ;
2331
2352
2332
2353
llama_tokens draft = common_speculative_gen_draft (slot.spec , params_spec, slot.cache_tokens , id);
2333
2354
2334
2355
// ignore small drafts
2335
2356
if (slot.params .speculative .n_min > (int ) draft.size ()) {
2357
+ SLT_DBG (slot, " ignoring small draft: %d < %d\n " , (int ) draft.size (), slot.params .speculative .n_min );
2358
+
2336
2359
continue ;
2337
2360
}
2338
2361
@@ -2344,6 +2367,8 @@ struct server_context {
2344
2367
common_batch_add (slot.batch_spec , draft[i], slot.n_past + 1 + i, { slot.id }, true );
2345
2368
}
2346
2369
2370
+ SLT_DBG (slot, " decoding speculative batch, size = %d\n " , slot.batch_spec .n_tokens );
2371
+
2347
2372
llama_decode (ctx, slot.batch_spec );
2348
2373
2349
2374
// the accepted tokens from the speculation
@@ -2372,7 +2397,7 @@ struct server_context {
2372
2397
}
2373
2398
}
2374
2399
2375
- SRV_DBG ( " accepted %d/%d draft tokens\n " , (int ) ids.size () - 1 , (int ) draft.size ());
2400
+ SLT_DBG (slot, " accepted %d/%d draft tokens, new n_past = %d \n " , (int ) ids.size () - 1 , (int ) draft.size (), slot. n_past );
2376
2401
}
2377
2402
}
2378
2403
0 commit comments