@@ -161,10 +161,17 @@ static const char * split_mode_str(llama_split_mode mode) {
161
161
}
162
162
}
163
163
164
// Formats an integer pair as text: p.first, a comma, then p.second, using the
// same printf pattern as before. The usage text passes this to
// transform_to_str() to print the default -pg values.
//
// p: the pair to format.
// Returns a newly constructed std::string holding the formatted text.
static std::string pair_str(const std::pair<int, int> & p) {
    // Automatic (non-static) storage keeps the function reentrant: a shared
    // static buffer could be overwritten by a concurrent call before the
    // result is copied into the returned std::string.
    // 32 bytes covers two 32-bit ints (at most 11 characters each) plus the
    // separator characters and the terminating NUL.
    char buf[32];
    snprintf(buf, sizeof(buf), " %d,%d", p.first, p.second);
    return buf;
}
169
+
164
170
struct cmd_params {
165
171
std::vector<std::string> model;
166
172
std::vector<int > n_prompt;
167
173
std::vector<int > n_gen;
174
+ std::vector<std::pair<int , int >> n_pg;
168
175
std::vector<int > n_batch;
169
176
std::vector<int > n_ubatch;
170
177
std::vector<ggml_type> type_k;
@@ -188,6 +195,7 @@ static const cmd_params cmd_params_defaults = {
188
195
/* model */ {" models/7B/ggml-model-q4_0.gguf" },
189
196
/* n_prompt */ {512 },
190
197
/* n_gen */ {128 },
198
+ /* n_pg */ {{512 , 128 }},
191
199
/* n_batch */ {2048 },
192
200
/* n_ubatch */ {512 },
193
201
/* type_k */ {GGML_TYPE_F16},
@@ -215,10 +223,11 @@ static void print_usage(int /* argc */, char ** argv) {
215
223
printf (" -m, --model <filename> (default: %s)\n " , join (cmd_params_defaults.model , " ," ).c_str ());
216
224
printf (" -p, --n-prompt <n> (default: %s)\n " , join (cmd_params_defaults.n_prompt , " ," ).c_str ());
217
225
printf (" -n, --n-gen <n> (default: %s)\n " , join (cmd_params_defaults.n_gen , " ," ).c_str ());
226
+ printf (" -pg <pp,tg> (default: %s)\n " , join (transform_to_str (cmd_params_defaults.n_pg , pair_str), " ," ).c_str ());
218
227
printf (" -b, --batch-size <n> (default: %s)\n " , join (cmd_params_defaults.n_batch , " ," ).c_str ());
219
- printf (" -ub N , --ubatch-size <n> (default: %s)\n " , join (cmd_params_defaults.n_ubatch , " ," ).c_str ());
220
- printf (" -ctk <t> , --cache-type-k <t> (default: %s)\n " , join (transform_to_str (cmd_params_defaults.type_k , ggml_type_name), " ," ).c_str ());
221
- printf (" -ctv <t> , --cache-type-v <t> (default: %s)\n " , join (transform_to_str (cmd_params_defaults.type_v , ggml_type_name), " ," ).c_str ());
228
+ printf (" -ub, --ubatch-size <n> (default: %s)\n " , join (cmd_params_defaults.n_ubatch , " ," ).c_str ());
229
+ printf (" -ctk, --cache-type-k <t> (default: %s)\n " , join (transform_to_str (cmd_params_defaults.type_k , ggml_type_name), " ," ).c_str ());
230
+ printf (" -ctv, --cache-type-v <t> (default: %s)\n " , join (transform_to_str (cmd_params_defaults.type_v , ggml_type_name), " ," ).c_str ());
222
231
printf (" -t, --threads <n> (default: %s)\n " , join (cmd_params_defaults.n_threads , " ," ).c_str ());
223
232
printf (" -ngl, --n-gpu-layers <n> (default: %s)\n " , join (cmd_params_defaults.n_gpu_layers , " ," ).c_str ());
224
233
printf (" -sm, --split-mode <none|layer|row> (default: %s)\n " , join (transform_to_str (cmd_params_defaults.split_mode , split_mode_str), " ," ).c_str ());
@@ -304,6 +313,17 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
304
313
}
305
314
auto p = split<int >(argv[i], split_delim);
306
315
params.n_gen .insert (params.n_gen .end (), p.begin (), p.end ());
316
+ } else if (arg == " -pg" ) {
317
+ if (++i >= argc) {
318
+ invalid_param = true ;
319
+ break ;
320
+ }
321
+ auto p = split<std::string>(argv[i], ' ,' );
322
+ if (p.size () != 2 ) {
323
+ invalid_param = true ;
324
+ break ;
325
+ }
326
+ params.n_pg .push_back ({std::stoi (p[0 ]), std::stoi (p[1 ])});
307
327
} else if (arg == " -b" || arg == " --batch-size" ) {
308
328
if (++i >= argc) {
309
329
invalid_param = true ;
@@ -493,6 +513,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
493
513
if (params.model .empty ()) { params.model = cmd_params_defaults.model ; }
494
514
if (params.n_prompt .empty ()) { params.n_prompt = cmd_params_defaults.n_prompt ; }
495
515
if (params.n_gen .empty ()) { params.n_gen = cmd_params_defaults.n_gen ; }
516
+ if (params.n_pg .empty ()) { params.n_pg = cmd_params_defaults.n_pg ; }
496
517
if (params.n_batch .empty ()) { params.n_batch = cmd_params_defaults.n_batch ; }
497
518
if (params.n_ubatch .empty ()) { params.n_ubatch = cmd_params_defaults.n_ubatch ; }
498
519
if (params.type_k .empty ()) { params.type_k = cmd_params_defaults.type_k ; }
@@ -632,6 +653,31 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
632
653
};
633
654
instances.push_back (instance);
634
655
}
656
+
657
+ for (const auto & n_pg : params.n_pg ) {
658
+ if (n_pg.first == 0 && n_pg.second == 0 ) {
659
+ continue ;
660
+ }
661
+ cmd_params_instance instance = {
662
+ /* .model = */ m,
663
+ /* .n_prompt = */ n_pg.first ,
664
+ /* .n_gen = */ n_pg.second ,
665
+ /* .n_batch = */ nb,
666
+ /* .n_ubatch = */ nub,
667
+ /* .type_k = */ tk,
668
+ /* .type_v = */ tv,
669
+ /* .n_threads = */ nt,
670
+ /* .n_gpu_layers = */ nl,
671
+ /* .split_mode = */ sm,
672
+ /* .main_gpu = */ mg,
673
+ /* .no_kv_offload= */ nkvo,
674
+ /* .flash_attn = */ fa,
675
+ /* .tensor_split = */ ts,
676
+ /* .use_mmap = */ mmp,
677
+ /* .embeddings = */ embd,
678
+ };
679
+ instances.push_back (instance);
680
+ }
635
681
}
636
682
637
683
return instances;
@@ -965,6 +1011,9 @@ struct markdown_printer : public printer {
965
1011
if (field == " n_gpu_layers" ) {
966
1012
return 3 ;
967
1013
}
1014
+ if (field == " test" ) {
1015
+ return 13 ;
1016
+ }
968
1017
969
1018
int width = std::max ((int )field.length (), 10 );
970
1019
@@ -1091,12 +1140,11 @@ struct markdown_printer : public printer {
1091
1140
value = test::get_backend ();
1092
1141
} else if (field == " test" ) {
1093
1142
if (t.n_prompt > 0 && t.n_gen == 0 ) {
1094
- snprintf (buf, sizeof (buf), " pp %d" , t.n_prompt );
1143
+ snprintf (buf, sizeof (buf), " pp%d" , t.n_prompt );
1095
1144
} else if (t.n_gen > 0 && t.n_prompt == 0 ) {
1096
- snprintf (buf, sizeof (buf), " tg %d" , t.n_gen );
1145
+ snprintf (buf, sizeof (buf), " tg%d" , t.n_gen );
1097
1146
} else {
1098
- assert (false );
1099
- exit (1 );
1147
+ snprintf (buf, sizeof (buf), " pp%d+tg%d" , t.n_prompt , t.n_gen );
1100
1148
}
1101
1149
value = buf;
1102
1150
} else if (field == " t/s" ) {
@@ -1297,6 +1345,7 @@ int main(int argc, char ** argv) {
1297
1345
llama_kv_cache_clear (ctx);
1298
1346
1299
1347
uint64_t t_start = get_time_ns ();
1348
+
1300
1349
if (t.n_prompt > 0 ) {
1301
1350
test_prompt (ctx, t.n_prompt , 0 , t.n_batch , t.n_threads );
1302
1351
}
0 commit comments