Commit 9ca8698

batched-bench : add fattn arg
1 parent c16a7c2 commit 9ca8698

1 file changed: +17 -11 lines

examples/batched-bench/batched-bench.cpp

Lines changed: 17 additions & 11 deletions
@@ -32,7 +32,7 @@ int main(int argc, char ** argv) {
     gpt_params params;
 
     if (argc == 1 || argv[1][0] == '-') {
-        printf("usage: %s MODEL_PATH [N_KV_MAX] [N_BATCH] [N_UBATCH] [IS_PP_SHARED] [NGL] <PP> <TG> <PL>\n" , argv[0]);
+        printf("usage: %s MODEL_PATH [N_KV_MAX] [N_BATCH] [N_UBATCH] [FATTN] [IS_PP_SHARED] [NGL] <PP> <TG> <PL>\n" , argv[0]);
         printf("  <PP>, <TG> and PL are comma-separated lists of numbers without spaces\n\n");
         printf("    example: %s ggml-model-f16.gguf 2048 2048 512 0 999 128,256,512 128,256 1,2,4,8,16,32\n\n", argv[0]);
         return 1 ;
@@ -41,6 +41,7 @@ int main(int argc, char ** argv) {
     int n_kv_max     = 2048;
     int n_batch      = 2048;
     int n_ubatch     = 512;
+    bool flash_attn  = false;
     int is_pp_shared = 0;
     int n_gpu_layers = 0;
 
@@ -66,23 +67,27 @@ int main(int argc, char ** argv) {
     }
 
     if (argc >= 6) {
-        is_pp_shared = std::atoi(argv[5]);
+        flash_attn = std::atoi(argv[5]);
     }
 
     if (argc >= 7) {
-        n_gpu_layers = std::atoi(argv[6]);
+        is_pp_shared = std::atoi(argv[6]);
     }
 
     if (argc >= 8) {
-        n_pp = parse_list(argv[7]);
+        n_gpu_layers = std::atoi(argv[7]);
     }
 
     if (argc >= 9) {
-        n_tg = parse_list(argv[8]);
+        n_pp = parse_list(argv[8]);
     }
 
     if (argc >= 10) {
-        n_pl = parse_list(argv[9]);
+        n_tg = parse_list(argv[9]);
+    }
+
+    if (argc >= 11) {
+        n_pl = parse_list(argv[10]);
     }
 
     // init LLM
@@ -108,10 +113,11 @@ int main(int argc, char ** argv) {
 
     llama_context_params ctx_params = llama_context_default_params();
 
-    ctx_params.seed     = 1234;
-    ctx_params.n_ctx    = n_kv_max;
-    ctx_params.n_batch  = n_batch;
-    ctx_params.n_ubatch = n_ubatch;
+    ctx_params.seed       = 1234;
+    ctx_params.n_ctx      = n_kv_max;
+    ctx_params.n_batch    = n_batch;
+    ctx_params.n_ubatch   = n_ubatch;
+    ctx_params.flash_attn = flash_attn;
 
     ctx_params.n_threads       = params.n_threads;
     ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
@@ -169,7 +175,7 @@ int main(int argc, char ** argv) {
     }
 
     LOG_TEE("\n");
-    LOG_TEE("%s: n_kv_max = %d, n_batch = %d, n_ubatch = %d, is_pp_shared = %d, n_gpu_layers = %d, n_threads = %u, n_threads_batch = %u\n", __func__, n_kv_max, n_batch, n_ubatch, is_pp_shared, n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch);
+    LOG_TEE("%s: n_kv_max = %d, n_batch = %d, n_ubatch = %d, flash_attn = %d, is_pp_shared = %d, n_gpu_layers = %d, n_threads = %u, n_threads_batch = %u\n", __func__, n_kv_max, n_batch, n_ubatch, flash_attn, is_pp_shared, n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch);
     LOG_TEE("\n");
 
     LOG_TEE("|%6s | %6s | %4s | %6s | %8s | %8s | %8s | %8s | %8s | %8s |\n", "PP", "TG", "B", "N_KV", "T_PP s", "S_PP t/s", "T_TG s", "S_TG t/s", "T s", "S t/s");
