@@ -714,6 +714,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
             if (params.logdir.back() != DIRECTORY_SEPARATOR) {
                 params.logdir += DIRECTORY_SEPARATOR;
             }
+        } else if (arg == "-lcs" || arg == "--lookup-cache-static") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.lookup_cache_static = argv[i];
         } else if (arg == "--save-all-logits" || arg == "--kl-divergence-base") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -1093,6 +1099,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("                        draft model for speculative decoding\n");
     printf("  -ld LOGDIR, --logdir LOGDIR\n");
     printf("                        path under which to save YAML logs (no logging if unset)\n");
+    printf("  -lcs FNAME, --lookup-cache-static FNAME\n");
+    printf("                        path to static lookup cache to use for lookup decoding\n");
     printf("  --override-kv KEY=TYPE:VALUE\n");
     printf("                        advanced option to override model metadata by key. may be specified multiple times.\n");
     printf("                        types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n");
@@ -1851,3 +1859,228 @@ void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size) {

     printf("\n=== Done dumping\n");
 }
+
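+// Add the last nnew tokens of inp to the n-gram caches in ncs. ncs holds one
+// cache per n-gram size, covering the sizes ngram_min through
+// ngram_min + ncs.size() - 1; each cache maps an n-gram to the counts of the
+// tokens that have been observed to follow it.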
+void llama_ngram_cache_update(std::vector<llama_ngram_cache> & ncs, int ngram_min,
+                              std::vector<llama_token> & inp, int nnew, bool print_progress) {
+    const int64_t t_start_ms = ggml_time_ms();
+    const int     ngram_max  = ngram_min + ncs.size()-1;
+    const int     inp_size   = inp.size();
+
+    for (int ngram_size = ngram_min; ngram_size <= ngram_max; ++ngram_size) {
+        llama_ngram_cache & nc = ncs[ngram_size - ngram_min];
+
+        const int i_start = std::max(inp_size - nnew, ngram_size);
+        for (int i = i_start; i < inp_size; ++i) {
+            const int ngram_start = i - ngram_size;
+            uint64_t ngram = inp[ngram_start];
+            for (int j = ngram_start+1; j < ngram_start + ngram_size; ++j) { // FIXME
+                const uint64_t ngram_part = inp[j];
+                ngram <<= 16;
+                ngram |= ngram_part;
+            }
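+            // the key packs the n-gram tokens 16 bits each into a uint64_t, which
+            // only works for n-grams of at most 4 tokens and token ids below 2^16
+            // (likely what the FIXME above refers to)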
+            const llama_token token = inp[i];
+
+            llama_ngram_cache::iterator part_it = nc.find(ngram);
+            if (part_it == nc.end()) {
+                llama_ngram_cache_part part;
+                part.emplace(token, 1);
+                nc.emplace(ngram, part);
+            } else {
+                llama_ngram_cache_part::iterator token_count_it = part_it->second.find(token);
+                if (token_count_it == part_it->second.end()) {
+                    part_it->second.emplace(token, 1);
+                } else {
+                    token_count_it->second++;
+                }
+            }
+            if (print_progress && i % 10000000 == 0) {
+                const int64_t t_now_ms = ggml_time_ms();
+                const int64_t eta_ms   = (inp_size - i) * (t_now_ms - t_start_ms) / i;
+                const int64_t eta_min  = eta_ms / (60*1000);
+                const int64_t eta_s    = (eta_ms - 60*1000*eta_min) / 1000;
+
+                fprintf(stderr, "%s: %d/%d done, ETA: %02ld:%02ld\n", __func__, i, inp_size, eta_min, eta_s);
+            }
+        }
+    }
+}
+
+// Helper function to get a token from the combined, speculative sequence of inp and draft.
+static llama_token get_token(const std::vector<llama_token> & inp, const std::vector<llama_token> & draft, const size_t i) {
+    return i < inp.size() ? inp[i] : draft[1 + i - inp.size()];
+}
+
+// If sample size or percentage in context are below these thresholds the draft is aborted early:
+constexpr int draft_min_sample_size[LLAMA_NGRAM_MAX] = { 2,  2,  1,  1};
+constexpr int draft_min_percent   [LLAMA_NGRAM_MAX]  = {50, 50, 50, 50};
+
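+// Try to extend draft by up to n_draft tokens: at each step the caches in
+// ncs_t1 (one per n-gram size, tried from largest to smallest) are queried for
+// the token that most frequently followed the current context suffix; counts
+// from the 2-gram cache nc_t2 weight the candidates, and nc_t2 alone is used
+// as a fallback when no n-gram cache yields a confident candidate.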
+void llama_ngram_cache_draft(
+    std::vector<llama_token> & inp, std::vector<llama_token> & draft, int n_draft,
+    std::vector<llama_ngram_cache> & ncs_t1, int ngram_min, llama_ngram_cache & nc_t2
+) {
+    const int inp_size  = inp.size();
+    const int ngram_max = ngram_min + ncs_t1.size()-1;
+
+    while ((int) draft.size()-1 < n_draft) {
+        bool draft_success = false;
+
+        const int ngram_start_t2 = inp_size-2 + draft.size()-1;
+        uint64_t ngram_t2 = get_token(inp, draft, ngram_start_t2);
+        for (int j = ngram_start_t2+1; j < ngram_start_t2 + 2; ++j) {
+            const uint64_t token = get_token(inp, draft, j);
+            ngram_t2 <<= 16;
+            ngram_t2 |= token;
+        }
+        llama_ngram_cache::iterator part_t2_it = nc_t2.find(ngram_t2);
+        llama_ngram_cache_part part_t2;
+        if (part_t2_it != nc_t2.end()) {
+            part_t2 = part_t2_it->second;
+        }
+
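+        // query the n-gram caches from the largest to the smallest n-gram size
+        // and draft the first candidate that passes the confidence thresholds: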
+        for (int ngram_size = ngram_max; ngram_size >= ngram_min; --ngram_size) {
+            if (ngram_size > inp_size) {
+                continue;
+            }
+
+            llama_ngram_cache & nc_t1 = ncs_t1[ngram_size - ngram_min];
+
+            const int ngram_start_t1 = inp_size-ngram_size + draft.size()-1;
+            uint64_t ngram_t1 = get_token(inp, draft, ngram_start_t1);
+            for (int j = ngram_start_t1+1; j < ngram_start_t1 + ngram_size; ++j) {
+                const uint64_t token = get_token(inp, draft, j);
+                ngram_t1 <<= 16;
+                ngram_t1 |= token;
+            }
+
+            llama_ngram_cache::iterator part_t1_it = nc_t1.find(ngram_t1);
+            if (part_t1_it == nc_t1.end()) {
+                continue;
+            }
+            const llama_ngram_cache_part part_t1 = part_t1_it->second;
+
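+            // pick the follower token with the highest count_t1*count_t2 score;
+            // t2 counts are scaled by 100, so candidates confirmed by the 2-gram
+            // cache strongly outweigh those that are not: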
+            int max_count_t1 = 0;
+            int max_count_t2 = 0;
+            int sum_count_t1 = 0;
+            llama_token max_token = -1;
+
+            for (std::pair<llama_token, int> token_count_t1 : part_t1) {
+                const llama_token token = token_count_t1.first;
+
+                llama_ngram_cache_part::iterator token_count_t2_it = part_t2.find(token);
+                const int32_t count_t1 = token_count_t1.second;
+                const int32_t count_t2 = token_count_t2_it != part_t2.end() ? 100*token_count_t2_it->second : 1;
+
+                if (count_t1*count_t2 > max_count_t1*max_count_t2) {
+                    max_token    = token;
+                    max_count_t1 = count_t1;
+                    max_count_t2 = count_t2;
+                }
+                sum_count_t1 += count_t1;
+            }
+            // skip this candidate if the sample size is too low:
+            if (sum_count_t1 < draft_min_sample_size[ngram_size-1]) {
+                continue;
+            }
+            // skip this candidate if the empirically most likely token following this n-gram is not likely enough:
+            if (100*max_count_t1 < draft_min_percent[ngram_size-1]*sum_count_t1) {
+                continue;
+            }
+
+            LOG(" - draft candidate: token=%d count=%d\n", max_token, max_count_t1);
+            draft.push_back(max_token);
+            draft_success = true;
+            break;
+        }
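+
+        // if no n-gram cache produced a confident candidate, fall back to
+        // drafting a single token from the 2-gram cache alone: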
+        if (!draft_success) {
+            int max_count_t2 = 0;
+            int sum_count_t2 = 0;
+            llama_token max_token = -1;
+
+            for (std::pair<llama_token, int> token_count_t2 : part_t2) {
+                const llama_token token = token_count_t2.first;
+                const int32_t count_t2  = token_count_t2.second;
+
+                if (count_t2 > max_count_t2) {
+                    max_token    = token;
+                    max_count_t2 = count_t2;
+                }
+                sum_count_t2 += count_t2;
+            }
+
+            // skip this candidate if the sample size is too low:
+            if (sum_count_t2 < draft_min_sample_size[2-1]) {
+                break;
+            }
+            // skip this candidate if the empirically most likely token following this n-gram is not likely enough:
+            if (100*max_count_t2 < draft_min_percent[2-1]*sum_count_t2) {
+                break;
+            }
+
+            LOG(" - draft candidate: token=%d count=%d\n", max_token, max_count_t2);
+            draft.push_back(max_token);
+            draft_success = true;
+            break;
+        }
+
+        if (!draft_success) {
+            break;
+        }
+    }
+}
+
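+// Serialize a single n-gram cache to a binary file, as a flat sequence of
+// records: n-gram key (uint64_t), number of distinct follower tokens (int32_t),
+// then one (llama_token, int32_t count) pair per follower token.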
+void llama_ngram_cache_save(std::vector<llama_ngram_cache> & ngram_cache, std::string & filename) {
+    GGML_ASSERT(ngram_cache.size() == 1);
+    std::ofstream file_out(filename, std::ios::binary);
+    for (std::pair<uint64_t, llama_ngram_cache_part> item : ngram_cache[0]) {
+        const uint64_t         ngram        = item.first;
+        llama_ngram_cache_part token_counts = item.second;
+        GGML_ASSERT(!token_counts.empty());
+        const int32_t ntokens = token_counts.size();
+
+        file_out.write(reinterpret_cast<const char *>(&ngram),   sizeof(uint64_t));
+        file_out.write(reinterpret_cast<const char *>(&ntokens), sizeof(int32_t));
+        for (std::pair<llama_token, int32_t> item2 : token_counts) {
+            const llama_token token = item2.first;
+            const int32_t     count = item2.second;
+            file_out.write(reinterpret_cast<const char *>(&token), sizeof(llama_token));
+            file_out.write(reinterpret_cast<const char *>(&count), sizeof(int32_t));
+        }
+    }
+}
+
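+// Load an n-gram cache from a binary file written by llama_ngram_cache_save.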
+llama_ngram_cache llama_ngram_cache_load(std::string & filename) {
+    std::ifstream hashmap_file(filename, std::ios::binary);
+    if (!hashmap_file) {
+        fprintf(stderr, "error: failed to open file '%s'\n", filename.c_str());
+        exit(1);
+    }
+    llama_ngram_cache ngram_cache;
+
+    uint64_t    ngram;
+    int32_t     ntokens;
+    llama_token token;
+    int32_t     count;
+
+    char * ngramc   = reinterpret_cast<char *>(&ngram);
+    char * ntokensc = reinterpret_cast<char *>(&ntokens);
+    char * tokenc   = reinterpret_cast<char *>(&token);
+    char * countc   = reinterpret_cast<char *>(&count);
+    while (hashmap_file.read(ngramc, sizeof(uint64_t))) {
+        GGML_ASSERT(hashmap_file.read(ntokensc, sizeof(int32_t)));
+        llama_ngram_cache_part token_counts;
+
+        for (int i = 0; i < ntokens; ++i) {
+            GGML_ASSERT(hashmap_file.read(tokenc, sizeof(llama_token)));
+            GGML_ASSERT(hashmap_file.read(countc, sizeof(int32_t)));
+            token_counts.emplace(token, count);
+        }
+
+        ngram_cache.emplace(ngram, token_counts);
+    }
+    GGML_ASSERT(hashmap_file.eof());
+
+    return ngram_cache;
+}