@@ -1879,9 +1879,9 @@ void llama_ngram_cache_update(llama_ngram_cache & ngram_cache, int ngram_min, in
1879
1879
const int i_start = std::max (inp_size - nnew, ngram_size);
1880
1880
for (int i = i_start; i < inp_size; ++i) {
1881
1881
const int ngram_start = i - ngram_size;
1882
- uint64_t ngram = inp[ngram_start];
1882
+ llama_ngram ngram = inp[ngram_start];
1883
1883
for (int j = ngram_start+1 ; j < ngram_start + ngram_size; ++j) { // FIXME
1884
- const uint64_t ngram_part = inp[j];
1884
+ const llama_ngram ngram_part = inp[j];
1885
1885
ngram <<= 16 ;
1886
1886
ngram |= ngram_part;
1887
1887
}
@@ -1904,7 +1904,7 @@ void llama_ngram_cache_update(llama_ngram_cache & ngram_cache, int ngram_min, in
1904
1904
const int64_t t_now_ms = ggml_time_ms ();
1905
1905
const int64_t eta_ms = (inp_size - i) * (t_now_ms - t_start_ms) / i;
1906
1906
const int64_t eta_min = eta_ms / (60 *1000 );
1907
- const int64_t eta_s = (eta_ms - eta_min) / 1000 ;
1907
+ const int64_t eta_s = (eta_ms - 60 * 1000 * eta_min) / 1000 ;
1908
1908
1909
1909
fprintf (stderr, " %s: %d/%d done, ETA: %02ld:%02ld\n " , __func__, i, inp_size, eta_min, eta_s);
1910
1910
}
@@ -1917,75 +1917,77 @@ static llama_token get_token(const std::vector<llama_token> & inp, const std::ve
1917
1917
return i < inp.size () ? inp[i] : draft[1 + i - inp.size ()];
1918
1918
};
1919
1919
1920
- // If sample size or percentage in context are below these thresholds the draft is aborted early:
1921
- constexpr int draft_min_sample_size_t1 [LLAMA_NGRAM_MAX] = { 2 , 2 , 1 , 1 };
1922
- constexpr int draft_min_percent_t1 [LLAMA_NGRAM_MAX] = {66 , 50 , 50 , 50 };
1923
- constexpr int draft_min_sample_size_t2 [LLAMA_NGRAM_MAX] = { 4 , 3 , 2 , 2 };
1924
- constexpr int draft_min_percent_t2 [LLAMA_NGRAM_MAX] = {75 , 66 , 66 , 66 };
1920
+ // If sample size or percentage are below these thresholds the draft is aborted early:
1921
+ constexpr int draft_min_sample_size_lax [LLAMA_NGRAM_MAX] = { 2 , 2 , 1 , 1 };
1922
+ constexpr int draft_min_percent_lax [LLAMA_NGRAM_MAX] = {66 , 50 , 50 , 50 };
1923
+ constexpr int draft_min_sample_size_strict [LLAMA_NGRAM_MAX] = { 4 , 3 , 2 , 2 };
1924
+ constexpr int draft_min_percent_strict [LLAMA_NGRAM_MAX] = {75 , 66 , 66 , 66 };
1925
1925
1926
- static llama_token try_draft (llama_ngram_cache & nc_primary, const uint64_t ngram_primary) {
1927
- llama_ngram_cache::iterator part_primary_it = nc_primary.find (ngram_primary);
1928
- if (part_primary_it == nc_primary.end ()) {
1926
+ // Helper function that tries to draft a token from only the static ngram cache:
1927
+ static llama_token try_draft (llama_ngram_cache & nc_static, const llama_ngram ngram_static) {
1928
+ llama_ngram_cache::iterator part_static_it = nc_static.find (ngram_static);
1929
+ if (part_static_it == nc_static.end ()) {
1929
1930
return -1 ;
1930
1931
}
1931
- const llama_ngram_cache_part part_primary = part_primary_it ->second ;
1932
+ const llama_ngram_cache_part part_static = part_static_it ->second ;
1932
1933
1933
- int max_count_primary = 0 ;
1934
- int sum_count_primary = 0 ;
1935
- llama_token max_token = -1 ;
1934
+ int max_count_static = 0 ;
1935
+ int sum_count_static = 0 ;
1936
+ llama_token max_token = -1 ;
1936
1937
1937
- for (std::pair<llama_token, int > token_count_primary : part_primary ) {
1938
- const llama_token token = token_count_primary .first ;
1939
- const int32_t count_primary = token_count_primary .second ;
1938
+ for (std::pair<llama_token, int > token_count_static : part_static ) {
1939
+ const llama_token token = token_count_static .first ;
1940
+ const int32_t count_static = token_count_static .second ;
1940
1941
1941
- if (count_primary > max_count_primary ) {
1942
- max_token = token;
1943
- max_count_primary = count_primary ;
1942
+ if (count_static > max_count_static ) {
1943
+ max_token = token;
1944
+ max_count_static = count_static ;
1944
1945
}
1945
- sum_count_primary += count_primary ;
1946
+ sum_count_static += count_static ;
1946
1947
}
1947
1948
1948
- if (sum_count_primary < draft_min_sample_size_t1[ 2 -1 ]) {
1949
+ if (sum_count_static < draft_min_sample_size_lax[LLAMA_NGRAM_STATIC -1 ]) {
1949
1950
return -1 ;
1950
1951
}
1951
- if (100 *max_count_primary < draft_min_percent_t1[ 2 -1 ]*sum_count_primary ) {
1952
+ if (100 *max_count_static < draft_min_percent_lax[LLAMA_NGRAM_STATIC -1 ]*sum_count_static ) {
1952
1953
return -1 ;
1953
1954
}
1954
1955
return max_token;
1955
1956
}
1956
1957
1958
+ // Try to draft a token from primary cache (context/dynamic), validate with static cache:
1957
1959
static llama_token try_draft (
1958
- llama_ngram_cache & nc_primary, const std::vector<uint64_t > & ngrams_primary, llama_ngram_cache_part & part_validate ,
1960
+ llama_ngram_cache & nc_primary, const std::vector<llama_ngram > & ngrams_primary, llama_ngram_cache_part & part_static ,
1959
1961
const int * min_sample_size, const int * min_percent) {
1960
1962
1961
1963
llama_token drafted_token = -1 ;
1962
1964
1963
1965
for (int i = ngrams_primary.size ()-1 ; i >= 0 && drafted_token == -1 ; --i) {
1964
- const uint64_t ngram_primary = ngrams_primary[i];
1966
+ const llama_ngram ngram_primary = ngrams_primary[i];
1965
1967
1966
1968
llama_ngram_cache::iterator part_primary_it = nc_primary.find (ngram_primary);
1967
1969
if (part_primary_it == nc_primary.end ()) {
1968
1970
continue ;
1969
1971
}
1970
1972
const llama_ngram_cache_part part_primary = part_primary_it->second ;
1971
1973
1972
- int max_count_primary = 0 ;
1973
- int max_count_validate = 0 ;
1974
- int sum_count_primary = 0 ;
1975
- llama_token max_token = -1 ;
1974
+ int max_count_primary = 0 ;
1975
+ int max_count_static = 0 ;
1976
+ int sum_count_primary = 0 ;
1977
+ llama_token max_token = -1 ;
1976
1978
1977
1979
for (std::pair<llama_token, int > token_count_primary : part_primary) {
1978
1980
const llama_token token = token_count_primary.first ;
1979
1981
1980
- llama_ngram_cache_part::iterator token_count_validate_it = part_validate .find (token);
1982
+ llama_ngram_cache_part::iterator token_count_static_it = part_static .find (token);
1981
1983
1982
- const int32_t count_primary = token_count_primary.second ;
1983
- const int32_t count_validate = token_count_validate_it != part_validate .end () ? 100 *token_count_validate_it ->second : 1 ;
1984
+ const int32_t count_primary = token_count_primary.second ;
1985
+ const int32_t count_static = token_count_static_it != part_static .end () ? 100 *token_count_static_it ->second : 1 ;
1984
1986
1985
- if (count_primary*count_validate > max_count_primary*max_count_validate ) {
1986
- max_token = token;
1987
- max_count_primary = count_primary;
1988
- max_count_validate = count_validate ;
1987
+ if (count_primary*count_static > max_count_primary*max_count_static ) {
1988
+ max_token = token;
1989
+ max_count_primary = count_primary;
1990
+ max_count_static = count_static ;
1989
1991
}
1990
1992
sum_count_primary += count_primary;
1991
1993
}
@@ -2004,49 +2006,51 @@ static llama_token try_draft(
2004
2006
2005
2007
void llama_ngram_cache_draft (
2006
2008
std::vector<llama_token> & inp, std::vector<llama_token> & draft, int n_draft, int ngram_min, int ngram_max,
2007
- llama_ngram_cache & nc_t1 , llama_ngram_cache & nc_t2 , llama_ngram_cache & nc_t3
2009
+ llama_ngram_cache & nc_context , llama_ngram_cache & nc_dynamic , llama_ngram_cache & nc_static
2008
2010
) {
2011
+ GGML_ASSERT (draft.size () == 1 );
2009
2012
const int inp_size = inp.size ();
2010
2013
2011
- if (inp_size < 2 ) {
2014
+ if (inp_size < LLAMA_NGRAM_STATIC ) {
2012
2015
return ;
2013
2016
}
2014
2017
2015
2018
while ((int ) draft.size ()-1 < n_draft) {
2016
2019
llama_token drafted_token = -1 ;
2017
2020
2018
- const int ngram_start_t23 = inp_size-2 + draft.size ()-1 ;
2019
- uint64_t ngram_t23 = get_token (inp, draft, ngram_start_t23 );
2020
- for (int j = ngram_start_t23 +1 ; j < ngram_start_t23 + 2 ; ++j) {
2021
- const uint64_t token = get_token (inp, draft, j);
2022
- ngram_t23 <<= 16 ;
2023
- ngram_t23 |= token;
2021
+ const int ngram_start_static = inp_size-LLAMA_NGRAM_STATIC + draft.size ()-1 ;
2022
+ llama_ngram ngram_static = get_token (inp, draft, ngram_start_static );
2023
+ for (int j = ngram_start_static +1 ; j < ngram_start_static + LLAMA_NGRAM_STATIC ; ++j) {
2024
+ const llama_ngram token = get_token (inp, draft, j);
2025
+ ngram_static <<= 16 ;
2026
+ ngram_static |= token;
2024
2027
}
2025
- llama_ngram_cache::iterator part_t3_it = nc_t3 .find (ngram_t23 );
2026
- llama_ngram_cache_part part_t3 ;
2027
- if (part_t3_it != nc_t3 .end ()) {
2028
- part_t3 = part_t3_it ->second ;
2028
+ llama_ngram_cache::iterator part_static_it = nc_static .find (ngram_static );
2029
+ llama_ngram_cache_part part_static ;
2030
+ if (part_static_it != nc_static .end ()) {
2031
+ part_static = part_static_it ->second ;
2029
2032
}
2030
2033
2031
- std::vector<uint64_t > ngrams_t12;
2032
- for (int ngram_size = ngram_min; ngram_size <= ngram_max; ++ngram_size) {
2033
- const int ngram_start_t12 = inp_size-ngram_size + draft.size ()-1 ;
2034
- uint64_t ngram_t12 = get_token (inp, draft, ngram_start_t12);
2035
- for (int j = ngram_start_t12+1 ; j < ngram_start_t12 + ngram_size; ++j) {
2036
- const uint64_t token = get_token (inp, draft, j);
2037
- ngram_t12 <<= 16 ;
2038
- ngram_t12 |= token;
2039
- }
2040
- ngrams_t12.push_back (ngram_t12);
2034
+ // cd = context + dynamic
2035
+ std::vector<llama_ngram> ngrams_cd;
2036
+ for (int ngram_size_cd = ngram_min; ngram_size_cd <= ngram_max; ++ngram_size_cd) {
2037
+ const int ngram_start_cd = inp_size-ngram_size_cd + draft.size ()-1 ;
2038
+ llama_ngram ngram_cd = get_token (inp, draft, ngram_start_cd);
2039
+ for (int j = ngram_start_cd+1 ; j < ngram_start_cd + ngram_size_cd; ++j) {
2040
+ const llama_ngram token = get_token (inp, draft, j);
2041
+ ngram_cd <<= 16 ;
2042
+ ngram_cd |= token;
2043
+ }
2044
+ ngrams_cd.push_back (ngram_cd);
2041
2045
}
2042
2046
if (drafted_token == -1 ) {
2043
- drafted_token = try_draft (nc_t1, ngrams_t12, part_t3, draft_min_sample_size_t1, draft_min_percent_t1 );
2047
+ drafted_token = try_draft (nc_context, ngrams_cd, part_static, draft_min_sample_size_lax, draft_min_percent_lax );
2044
2048
}
2045
2049
if (drafted_token == -1 ) {
2046
- drafted_token = try_draft (nc_t2, ngrams_t12, part_t3, draft_min_sample_size_t2, draft_min_percent_t2 );
2050
+ drafted_token = try_draft (nc_dynamic, ngrams_cd, part_static, draft_min_sample_size_strict, draft_min_percent_strict );
2047
2051
}
2048
2052
if (drafted_token == -1 ) {
2049
- drafted_token = try_draft (nc_t3, ngram_t23 );
2053
+ drafted_token = try_draft (nc_static, ngram_static );
2050
2054
}
2051
2055
2052
2056
if (drafted_token == -1 ) {
@@ -2060,14 +2064,13 @@ void llama_ngram_cache_draft(
2060
2064
2061
2065
void llama_ngram_cache_save (llama_ngram_cache & ngram_cache, std::string & filename) {
2062
2066
std::ofstream file_out (filename, std::ios::binary);
2063
- for (std::pair<uint64_t , llama_ngram_cache_part> item : ngram_cache) {
2064
- const uint64_t ngram = item.first ;
2067
+ for (std::pair<llama_ngram , llama_ngram_cache_part> item : ngram_cache) {
2068
+ const llama_ngram ngram = item.first ;
2065
2069
llama_ngram_cache_part token_counts = item.second ;
2066
2070
GGML_ASSERT (!token_counts.empty ());
2067
2071
const int32_t ntokens = token_counts.size ();
2068
2072
2069
-
2070
- file_out.write (reinterpret_cast <const char *>(&ngram), sizeof (uint64_t ));
2073
+ file_out.write (reinterpret_cast <const char *>(&ngram), sizeof (llama_ngram));
2071
2074
file_out.write (reinterpret_cast <const char *>(&ntokens), sizeof (int32_t ));
2072
2075
for (std::pair<llama_token, int32_t > item2 : token_counts) {
2073
2076
const llama_token token = item2.first ;
@@ -2086,16 +2089,16 @@ llama_ngram_cache llama_ngram_cache_load(std::string & filename) {
2086
2089
}
2087
2090
llama_ngram_cache ngram_cache;
2088
2091
2089
- uint64_t ngram;
2090
- int32_t ntokens;
2092
+ llama_ngram ngram;
2093
+ int32_t ntokens;
2091
2094
llama_token token;
2092
- int32_t count;
2095
+ int32_t count;
2093
2096
2094
2097
char * ngramc = reinterpret_cast <char *>(&ngram);
2095
2098
char * ntokensc = reinterpret_cast <char *>(&ntokens);
2096
2099
char * tokenc = reinterpret_cast <char *>(&token);
2097
2100
char * countc = reinterpret_cast <char *>(&count);
2098
- while (hashmap_file.read (ngramc, sizeof (uint64_t ))) {
2101
+ while (hashmap_file.read (ngramc, sizeof (llama_ngram ))) {
2099
2102
GGML_ASSERT (hashmap_file.read (ntokensc, sizeof (int32_t )));
2100
2103
llama_ngram_cache_part token_counts;
2101
2104
@@ -2113,8 +2116,8 @@ llama_ngram_cache llama_ngram_cache_load(std::string & filename) {
2113
2116
}
2114
2117
2115
2118
void llama_ngram_cache_merge (llama_ngram_cache & ngram_cache_target, llama_ngram_cache & ngram_cache_add) {
2116
- for (std::pair<uint64_t , llama_ngram_cache_part> ngram_part : ngram_cache_add) {
2117
- const uint64_t ngram = ngram_part.first ;
2119
+ for (std::pair<llama_ngram , llama_ngram_cache_part> ngram_part : ngram_cache_add) {
2120
+ const llama_ngram ngram = ngram_part.first ;
2118
2121
llama_ngram_cache_part part = ngram_part.second ;
2119
2122
2120
2123
llama_ngram_cache::iterator part_merged_it = ngram_cache_target.find (ngram);
0 commit comments