Skip to content

Commit f27ddc5

Browse files
committed
speculative : add --draft-min CLI arg
1 parent 0d4d0c1 commit f27ddc5

File tree

3 files changed

+12
-7
lines changed

3 files changed

+12
-7
lines changed

common/arg.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -609,7 +609,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
609609
[](common_params & params, int value) {
610610
params.n_draft = value;
611611
}
612-
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_LOOKUP}));
612+
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_SERVER}));
613+
add_opt(common_arg(
614+
{"--draft-min"}, "N",
615+
string_format("minimum number of draft tokens to use for speculative decoding (default: %d)", params.n_draft_min),
616+
[](common_params & params, int value) {
617+
params.n_draft_min = value;
618+
}
619+
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_SERVER}));
613620
add_opt(common_arg(
614621
{"-ps", "--p-split"}, "N",
615622
string_format("speculative decoding split probability (default: %.1f)", (double)params.p_split),
@@ -1454,7 +1461,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
14541461
fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
14551462
}
14561463
}
1457-
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
1464+
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}));
14581465
add_opt(common_arg(
14591466
{"-sm", "--split-mode"}, "{none,layer,row}",
14601467
"how to split the model across multiple GPUs, one of:\n"
@@ -1599,7 +1606,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
15991606
[](common_params & params, const std::string & value) {
16001607
params.model_draft = value;
16011608
}
1602-
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
1609+
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}));
16031610
add_opt(common_arg(
16041611
{"-mu", "--model-url"}, "MODEL_URL",
16051612
"model download url (default: unused)",

common/common.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@ struct common_params {
162162
int32_t n_ubatch = 512; // physical batch size for prompt processing (must be >=32 to use BLAS)
163163
int32_t n_keep = 0; // number of tokens to keep from initial prompt
164164
int32_t n_draft = 5; // number of tokens to draft during speculative decoding
165+
int32_t n_draft_min = 0; // minimum number of draft tokens to use for speculative decoding
165166
int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
166167
int32_t n_parallel = 1; // number of parallel sequences to decode
167168
int32_t n_sequences = 1; // number of sequences to decode

examples/speculative-simple/speculative-simple.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,6 @@
1313
int main(int argc, char ** argv) {
1414
common_params params;
1515

16-
// minimum size of the draft to use
17-
const int n_min = 5;
18-
1916
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SPECULATIVE)) {
2017
return 1;
2118
}
@@ -142,7 +139,7 @@ int main(int argc, char ** argv) {
142139
// evaluate the target model on [id_last, draft0, draft1, ..., draftN-1]
143140
{
144141
// do not waste time on small drafts
145-
if (draft.size() < n_min) {
142+
if (draft.size() < params.n_draft_min) {
146143
draft.clear();
147144
}
148145

0 commit comments

Comments (0)