@@ -1736,7 +1736,7 @@ struct sql_printer : public printer {
1736
1736
}
1737
1737
};
1738
1738
1739
- static void test_prompt (llama_context * ctx, int n_prompt, int n_batch, int n_threads) {
1739
+ static bool test_prompt (llama_context * ctx, int n_prompt, int n_batch, int n_threads) {
1740
1740
llama_set_n_threads (ctx, n_threads, n_threads);
1741
1741
1742
1742
const llama_model * model = llama_get_model (ctx);
@@ -1753,14 +1753,19 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_th
1753
1753
for (int i = 1 ; i < n_tokens; i++) {
1754
1754
tokens[i] = std::rand () % n_vocab;
1755
1755
}
1756
- llama_decode (ctx, llama_batch_get_one (tokens.data (), n_tokens));
1756
+ int res = llama_decode (ctx, llama_batch_get_one (tokens.data (), n_tokens));
1757
+ if (res != 0 ) {
1758
+ fprintf (stderr, " %s: failed to decode prompt batch, res = %d\n " , __func__, res);
1759
+ return false ;
1760
+ }
1757
1761
n_processed += n_tokens;
1758
1762
}
1759
1763
1760
1764
llama_synchronize (ctx);
1765
+ return true ;
1761
1766
}
1762
1767
1763
- static void test_gen (llama_context * ctx, int n_gen, int n_threads) {
1768
+ static bool test_gen (llama_context * ctx, int n_gen, int n_threads) {
1764
1769
llama_set_n_threads (ctx, n_threads, n_threads);
1765
1770
1766
1771
const llama_model * model = llama_get_model (ctx);
@@ -1770,10 +1775,15 @@ static void test_gen(llama_context * ctx, int n_gen, int n_threads) {
1770
1775
llama_token token = llama_vocab_get_add_bos (vocab) ? llama_vocab_bos (vocab) : std::rand () % n_vocab;
1771
1776
1772
1777
for (int i = 0 ; i < n_gen; i++) {
1773
- llama_decode (ctx, llama_batch_get_one (&token, 1 ));
1778
+ int res = llama_decode (ctx, llama_batch_get_one (&token, 1 ));
1779
+ if (res != 0 ) {
1780
+ fprintf (stderr, " %s: failed to decode generation batch, res = %d\n " , __func__, res);
1781
+ return false ;
1782
+ }
1774
1783
llama_synchronize (ctx);
1775
1784
token = std::rand () % n_vocab;
1776
1785
}
1786
+ return true ;
1777
1787
}
1778
1788
1779
1789
static void llama_null_log_callback (enum ggml_log_level level, const char * text, void * user_data) {
@@ -1917,13 +1927,21 @@ int main(int argc, char ** argv) {
1917
1927
fprintf (stderr, " llama-bench: benchmark %d/%zu: warmup prompt run\n " , params_idx, params_count);
1918
1928
}
1919
1929
// test_prompt(ctx, std::min(t.n_batch, std::min(t.n_prompt, 32)), 0, t.n_batch, t.n_threads);
1920
- test_prompt (ctx, t.n_prompt , t.n_batch , t.n_threads );
1930
+ bool res = test_prompt (ctx, t.n_prompt , t.n_batch , t.n_threads );
1931
+ if (!res) {
1932
+ fprintf (stderr, " %s: error: failed to run prompt warmup\n " , __func__);
1933
+ exit (1 );
1934
+ }
1921
1935
}
1922
1936
if (t.n_gen > 0 ) {
1923
1937
if (params.progress ) {
1924
1938
fprintf (stderr, " llama-bench: benchmark %d/%zu: warmup generation run\n " , params_idx, params_count);
1925
1939
}
1926
- test_gen (ctx, 1 , t.n_threads );
1940
+ bool res = test_gen (ctx, 1 , t.n_threads );
1941
+ if (!res) {
1942
+ fprintf (stderr, " %s: error: failed to run gen warmup\n " , __func__);
1943
+ exit (1 );
1944
+ }
1927
1945
}
1928
1946
1929
1947
for (int i = 0 ; i < params.reps ; i++) {
@@ -1934,7 +1952,11 @@ int main(int argc, char ** argv) {
1934
1952
fprintf (stderr, " llama-bench: benchmark %d/%zu: depth run %d/%d\n " , params_idx, params_count,
1935
1953
i + 1 , params.reps );
1936
1954
}
1937
- test_prompt (ctx, t.n_depth , t.n_batch , t.n_threads );
1955
+ bool res = test_prompt (ctx, t.n_depth , t.n_batch , t.n_threads );
1956
+ if (!res) {
1957
+ fprintf (stderr, " %s: error: failed to run depth\n " , __func__);
1958
+ exit (1 );
1959
+ }
1938
1960
}
1939
1961
1940
1962
uint64_t t_start = get_time_ns ();
@@ -1944,14 +1966,22 @@ int main(int argc, char ** argv) {
1944
1966
fprintf (stderr, " llama-bench: benchmark %d/%zu: prompt run %d/%d\n " , params_idx, params_count,
1945
1967
i + 1 , params.reps );
1946
1968
}
1947
- test_prompt (ctx, t.n_prompt , t.n_batch , t.n_threads );
1969
+ bool res = test_prompt (ctx, t.n_prompt , t.n_batch , t.n_threads );
1970
+ if (!res) {
1971
+ fprintf (stderr, " %s: error: failed to run prompt\n " , __func__);
1972
+ exit (1 );
1973
+ }
1948
1974
}
1949
1975
if (t.n_gen > 0 ) {
1950
1976
if (params.progress ) {
1951
1977
fprintf (stderr, " llama-bench: benchmark %d/%zu: generation run %d/%d\n " , params_idx, params_count,
1952
1978
i + 1 , params.reps );
1953
1979
}
1954
- test_gen (ctx, t.n_gen , t.n_threads );
1980
+ bool res = test_gen (ctx, t.n_gen , t.n_threads );
1981
+ if (!res) {
1982
+ fprintf (stderr, " %s: error: failed to run gen\n " , __func__);
1983
+ exit (1 );
1984
+ }
1955
1985
}
1956
1986
1957
1987
uint64_t t_ns = get_time_ns () - t_start;
0 commit comments