Commit b283804

bench : handle decode errors (#13548)
ggml-ci
1 parent aa48e37 commit b283804


tools/llama-bench/llama-bench.cpp

Lines changed: 39 additions & 9 deletions
@@ -1736,7 +1736,7 @@ struct sql_printer : public printer {
     }
 };
 
-static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_threads) {
+static bool test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_threads) {
     llama_set_n_threads(ctx, n_threads, n_threads);
 
     const llama_model * model = llama_get_model(ctx);
@@ -1753,14 +1753,19 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_threads) {
         for (int i = 1; i < n_tokens; i++) {
             tokens[i] = std::rand() % n_vocab;
         }
-        llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens));
+        int res = llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens));
+        if (res != 0) {
+            fprintf(stderr, "%s: failed to decode prompt batch, res = %d\n", __func__, res);
+            return false;
+        }
         n_processed += n_tokens;
     }
 
     llama_synchronize(ctx);
+    return true;
 }
 
-static void test_gen(llama_context * ctx, int n_gen, int n_threads) {
+static bool test_gen(llama_context * ctx, int n_gen, int n_threads) {
     llama_set_n_threads(ctx, n_threads, n_threads);
 
     const llama_model * model = llama_get_model(ctx);
@@ -1770,10 +1775,15 @@ static void test_gen(llama_context * ctx, int n_gen, int n_threads) {
     llama_token token = llama_vocab_get_add_bos(vocab) ? llama_vocab_bos(vocab) : std::rand() % n_vocab;
 
     for (int i = 0; i < n_gen; i++) {
-        llama_decode(ctx, llama_batch_get_one(&token, 1));
+        int res = llama_decode(ctx, llama_batch_get_one(&token, 1));
+        if (res != 0) {
+            fprintf(stderr, "%s: failed to decode generation batch, res = %d\n", __func__, res);
+            return false;
+        }
         llama_synchronize(ctx);
         token = std::rand() % n_vocab;
     }
+    return true;
 }
 
 static void llama_null_log_callback(enum ggml_log_level level, const char * text, void * user_data) {
@@ -1917,13 +1927,21 @@ int main(int argc, char ** argv) {
                 fprintf(stderr, "llama-bench: benchmark %d/%zu: warmup prompt run\n", params_idx, params_count);
             }
             //test_prompt(ctx, std::min(t.n_batch, std::min(t.n_prompt, 32)), 0, t.n_batch, t.n_threads);
-            test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads);
+            bool res = test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads);
+            if (!res) {
+                fprintf(stderr, "%s: error: failed to run prompt warmup\n", __func__);
+                exit(1);
+            }
         }
         if (t.n_gen > 0) {
             if (params.progress) {
                 fprintf(stderr, "llama-bench: benchmark %d/%zu: warmup generation run\n", params_idx, params_count);
             }
-            test_gen(ctx, 1, t.n_threads);
+            bool res = test_gen(ctx, 1, t.n_threads);
+            if (!res) {
+                fprintf(stderr, "%s: error: failed to run gen warmup\n", __func__);
+                exit(1);
+            }
         }
 
         for (int i = 0; i < params.reps; i++) {
@@ -1934,7 +1952,11 @@ int main(int argc, char ** argv) {
                     fprintf(stderr, "llama-bench: benchmark %d/%zu: depth run %d/%d\n", params_idx, params_count,
                             i + 1, params.reps);
                 }
-                test_prompt(ctx, t.n_depth, t.n_batch, t.n_threads);
+                bool res = test_prompt(ctx, t.n_depth, t.n_batch, t.n_threads);
+                if (!res) {
+                    fprintf(stderr, "%s: error: failed to run depth\n", __func__);
+                    exit(1);
+                }
             }
 
             uint64_t t_start = get_time_ns();
@@ -1944,14 +1966,22 @@ int main(int argc, char ** argv) {
                     fprintf(stderr, "llama-bench: benchmark %d/%zu: prompt run %d/%d\n", params_idx, params_count,
                             i + 1, params.reps);
                 }
-                test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads);
+                bool res = test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads);
+                if (!res) {
+                    fprintf(stderr, "%s: error: failed to run prompt\n", __func__);
+                    exit(1);
+                }
             }
             if (t.n_gen > 0) {
                 if (params.progress) {
                     fprintf(stderr, "llama-bench: benchmark %d/%zu: generation run %d/%d\n", params_idx, params_count,
                             i + 1, params.reps);
                 }
-                test_gen(ctx, t.n_gen, t.n_threads);
+                bool res = test_gen(ctx, t.n_gen, t.n_threads);
+                if (!res) {
+                    fprintf(stderr, "%s: error: failed to run gen\n", __func__);
+                    exit(1);
+                }
             }
 
             uint64_t t_ns = get_time_ns() - t_start;
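
In short, both benchmark helpers now return a bool and the callers in main() abort on failure instead of silently timing a run whose decode never succeeded. Below is a condensed, hedged sketch of that pattern under the same assumption the diff makes (any non-zero llama_decode result is a failure); decode_one_batch is a hypothetical helper name used only for illustration, not part of llama-bench.

// Minimal sketch of the error-handling pattern applied by this commit
// (hypothetical helper, not llama-bench code): propagate llama_decode
// failures to the caller as a bool instead of ignoring them.
#include <cstdio>
#include <cstdlib>
#include <vector>

#include "llama.h"

static bool decode_one_batch(llama_context * ctx, std::vector<llama_token> & tokens) {
    // llama_decode returns 0 on success; treat any non-zero value as a
    // failure, matching the check introduced above.
    int res = llama_decode(ctx, llama_batch_get_one(tokens.data(), (int32_t) tokens.size()));
    if (res != 0) {
        fprintf(stderr, "%s: failed to decode batch, res = %d\n", __func__, res);
        return false;
    }
    llama_synchronize(ctx);
    return true;
}

// Caller side, mirroring main(): stop the benchmark rather than report
// timings for a run that never decoded.
//   if (!decode_one_batch(ctx, tokens)) {
//       exit(1);
//   }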
