Skip to content

Commit 0e500a3

Browse files
committed
Add support for the --numa argument to llama-bench.
1 parent 8425001 commit 0e500a3

File tree

1 file changed

+16
-0
lines changed

1 file changed

+16
-0
lines changed

examples/llama-bench/llama-bench.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,8 @@ struct cmd_params {
178178
std::vector<std::vector<float>> tensor_split;
179179
std::vector<bool> use_mmap;
180180
std::vector<bool> embeddings;
181+
// NOTE: it may not be safe to call llama_numa_init multiple times, so this is a single value rather than a vector.
182+
ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED;
181183
int reps;
182184
bool verbose;
183185
output_formats output_format;
@@ -200,6 +202,7 @@ static const cmd_params cmd_params_defaults = {
200202
/* tensor_split */ {std::vector<float>(llama_max_devices(), 0.0f)},
201203
/* use_mmap */ {true},
202204
/* embeddings */ {false},
205+
/* numa */ GGML_NUMA_STRATEGY_DISABLED,
203206
/* reps */ 5,
204207
/* verbose */ false,
205208
/* output_format */ MARKDOWN
@@ -224,6 +227,7 @@ static void print_usage(int /* argc */, char ** argv) {
224227
printf(" -nkvo, --no-kv-offload <0|1> (default: %s)\n", join(cmd_params_defaults.no_kv_offload, ",").c_str());
225228
printf(" -fa, --flash-attn <0|1> (default: %s)\n", join(cmd_params_defaults.flash_attn, ",").c_str());
226229
printf(" -mmp, --mmap <0|1> (default: %s)\n", join(cmd_params_defaults.use_mmap, ",").c_str());
230+
printf(" --numa <distribute|isolate|numactl> (default: disabled)\n");
227231
printf(" -embd, --embeddings <0|1> (default: %s)\n", join(cmd_params_defaults.embeddings, ",").c_str());
228232
printf(" -ts, --tensor-split <ts0/ts1/..> (default: 0)\n");
229233
printf(" -r, --repetitions <n> (default: %d)\n", cmd_params_defaults.reps);
@@ -396,6 +400,17 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
396400
}
397401
auto p = split<bool>(argv[i], split_delim);
398402
params.no_kv_offload.insert(params.no_kv_offload.end(), p.begin(), p.end());
403+
} else if (arg == "--numa") {
404+
if (++i >= argc) {
405+
invalid_param = true;
406+
break;
407+
} else {
408+
std::string value(argv[i]);
409+
/**/ if (value == "distribute" || value == "" ) { params.numa = GGML_NUMA_STRATEGY_DISTRIBUTE; }
410+
else if (value == "isolate") { params.numa = GGML_NUMA_STRATEGY_ISOLATE; }
411+
else if (value == "numactl") { params.numa = GGML_NUMA_STRATEGY_NUMACTL; }
412+
else { invalid_param = true; break; }
413+
}
399414
} else if (arg == "-fa" || arg == "--flash-attn") {
400415
if (++i >= argc) {
401416
invalid_param = true;
@@ -1215,6 +1230,7 @@ int main(int argc, char ** argv) {
12151230
llama_log_set(llama_null_log_callback, NULL);
12161231
}
12171232
llama_backend_init();
1233+
llama_numa_init(params.numa);
12181234

12191235
// initialize printer
12201236
std::unique_ptr<printer> p;

0 commit comments

Comments
 (0)