@@ -178,6 +178,8 @@ struct cmd_params {
178
178
std::vector<std::vector<float >> tensor_split;
179
179
std::vector<bool > use_mmap;
180
180
std::vector<bool > embeddings;
181
+ // Single value rather than a vector: it is unclear whether llama_numa_init is safe to call more than once.
182
+ ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED;
181
183
int reps;
182
184
bool verbose;
183
185
output_formats output_format;
@@ -200,6 +202,7 @@ static const cmd_params cmd_params_defaults = {
200
202
/* tensor_split */ {std::vector<float >(llama_max_devices (), 0 .0f )},
201
203
/* use_mmap */ {true },
202
204
/* embeddings */ {false },
205
+ /* numa */ GGML_NUMA_STRATEGY_DISABLED,
203
206
/* reps */ 5 ,
204
207
/* verbose */ false ,
205
208
/* output_format */ MARKDOWN
@@ -224,6 +227,7 @@ static void print_usage(int /* argc */, char ** argv) {
224
227
printf (" -nkvo, --no-kv-offload <0|1> (default: %s)\n " , join (cmd_params_defaults.no_kv_offload , " ," ).c_str ());
225
228
printf (" -fa, --flash-attn <0|1> (default: %s)\n " , join (cmd_params_defaults.flash_attn , " ," ).c_str ());
226
229
printf (" -mmp, --mmap <0|1> (default: %s)\n " , join (cmd_params_defaults.use_mmap , " ," ).c_str ());
230
+ printf (" --numa <distribute|isolate|numactl> (default: disabled)\n " );
227
231
printf (" -embd, --embeddings <0|1> (default: %s)\n " , join (cmd_params_defaults.embeddings , " ," ).c_str ());
228
232
printf (" -ts, --tensor-split <ts0/ts1/..> (default: 0)\n " );
229
233
printf (" -r, --repetitions <n> (default: %d)\n " , cmd_params_defaults.reps );
@@ -396,6 +400,17 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
396
400
}
397
401
auto p = split<bool >(argv[i], split_delim);
398
402
params.no_kv_offload .insert (params.no_kv_offload .end (), p.begin (), p.end ());
403
+ } else if (arg == " --numa" ) {
404
+ if (++i >= argc) {
405
+ invalid_param = true ;
406
+ break ;
407
+ } else {
408
+ std::string value (argv[i]);
409
+ /* */ if (value == " distribute" || value == " " ) { params.numa = GGML_NUMA_STRATEGY_DISTRIBUTE; }
410
+ else if (value == " isolate" ) { params.numa = GGML_NUMA_STRATEGY_ISOLATE; }
411
+ else if (value == " numactl" ) { params.numa = GGML_NUMA_STRATEGY_NUMACTL; }
412
+ else { invalid_param = true ; break ; }
413
+ }
399
414
} else if (arg == " -fa" || arg == " --flash-attn" ) {
400
415
if (++i >= argc) {
401
416
invalid_param = true ;
@@ -1215,6 +1230,7 @@ int main(int argc, char ** argv) {
1215
1230
llama_log_set (llama_null_log_callback, NULL );
1216
1231
}
1217
1232
llama_backend_init ();
1233
+ llama_numa_init (params.numa );
1218
1234
1219
1235
// initialize printer
1220
1236
std::unique_ptr<printer> p;
0 commit comments