
Commit 650d92e

move code to common, CLI arg, lookup-stats
1 parent 76e91a6 commit 650d92e

6 files changed: +450 additions, −302 deletions


Makefile

Lines changed: 2 additions & 0 deletions
@@ -786,6 +786,8 @@ lookup: examples/lookup/lookup.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
	$(CXX) $(CXXFLAGS) -c examples/lookup/lookup-create.cpp -o $(call GET_OBJ_FILE, examples/lookup/lookup-create.cpp)
	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, examples/lookup/lookup-create.cpp) -o lookup-create $(LDFLAGS)
+	$(CXX) $(CXXFLAGS) -c examples/lookup/lookup-stats.cpp -o $(call GET_OBJ_FILE, examples/lookup/lookup-stats.cpp)
+	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, examples/lookup/lookup-stats.cpp) -o lookup-stats $(LDFLAGS)

passkey: examples/passkey/passkey.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)

common/common.cpp

Lines changed: 233 additions & 0 deletions
@@ -714,6 +714,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
            if (params.logdir.back() != DIRECTORY_SEPARATOR) {
                params.logdir += DIRECTORY_SEPARATOR;
            }
+        } else if (arg == "-lcs" || arg == "--lookup-cache-static") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.lookup_cache_static = argv[i];
        } else if (arg == "--save-all-logits" || arg == "--kl-divergence-base") {
            if (++i >= argc) {
                invalid_param = true;
@@ -1093,6 +1099,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
    printf("                        draft model for speculative decoding\n");
    printf("  -ld LOGDIR, --logdir LOGDIR\n");
    printf("                        path under which to save YAML logs (no logging if unset)\n");
+   printf("  -lcs FNAME, --lookup-cache-static FNAME\n");
+   printf("                        path to static lookup cache to use for lookup decoding\n");
    printf("  --override-kv KEY=TYPE:VALUE\n");
    printf("                        advanced option to override model metadata by key. may be specified multiple times.\n");
    printf("                        types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n");
@@ -1851,3 +1859,228 @@ void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size) {

    printf("\n=== Done dumping\n");
}
+
+void llama_ngram_cache_update(std::vector<llama_ngram_cache> & ncs, int ngram_min,
+                              std::vector<llama_token> & inp, int nnew, bool print_progress) {
+    const int64_t t_start_ms = ggml_time_ms();
+    const int ngram_max = ngram_min + ncs.size()-1;
+    const int inp_size = inp.size();
+
+    for (int ngram_size = ngram_min; ngram_size <= ngram_max; ++ngram_size) {
+        llama_ngram_cache & nc = ncs[ngram_size - ngram_min];
+
+        const int i_start = std::max(inp_size - nnew, ngram_size);
+        for (int i = i_start; i < inp_size; ++i) {
+            const int ngram_start = i - ngram_size;
+            uint64_t ngram = inp[ngram_start];
+            for (int j = ngram_start+1; j < ngram_start + ngram_size; ++j) { // FIXME
+                const uint64_t ngram_part = inp[j];
+                ngram <<= 16;
+                ngram |= ngram_part;
+            }
+            const llama_token token = inp[i];
+
+            llama_ngram_cache::iterator part_it = nc.find(ngram);
+            if (part_it == nc.end()) {
+                llama_ngram_cache_part part;
+                part.emplace(token, 1);
+                nc.emplace(ngram, part);
+            } else {
+                llama_ngram_cache_part::iterator token_count_it = part_it->second.find(token);
+                if (token_count_it == part_it->second.end()) {
+                    part_it->second.emplace(token, 1);
+                } else {
+                    token_count_it->second++;
+                }
+            }
+            if (print_progress && i % 10000000 == 0) {
+                const int64_t t_now_ms = ggml_time_ms();
+                const int64_t eta_ms = (inp_size - i) * (t_now_ms - t_start_ms) / i;
+                const int64_t eta_min = eta_ms / (60*1000);
+                const int64_t eta_s = (eta_ms - eta_min) / 1000;
+
+                fprintf(stderr, "%s: %d/%d done, ETA: %02ld:%02ld\n", __func__, i, inp_size, eta_min, eta_s);
+            }
+        }
+    }
+}
+
+// Helper function to get a token from the combined, speculative sequence of inp and draft.
+static llama_token get_token(const std::vector<llama_token> & inp, const std::vector<llama_token> & draft, const size_t i) {
+    return i < inp.size() ? inp[i] : draft[1 + i - inp.size()];
+};
+
+// If sample size or percentage in context are below these thresholds the draft is aborted early:
+constexpr int draft_min_sample_size[LLAMA_NGRAM_MAX] = { 2,  2,  1,  1};
+constexpr int draft_min_percent[LLAMA_NGRAM_MAX]     = {50, 50, 50, 50};
+
+void llama_ngram_cache_draft(
+    std::vector<llama_token> & inp, std::vector<llama_token> & draft, int n_draft,
+    std::vector<llama_ngram_cache> & ncs_t1, int ngram_min, llama_ngram_cache & nc_t2
+) {
+    const int inp_size = inp.size();
+    const int ngram_max = ngram_min + ncs_t1.size()-1;
+
+    while ((int) draft.size()-1 < n_draft) {
+        bool draft_success = false;
+
+        const int ngram_start_t2 = inp_size-2 + draft.size()-1;
+        uint64_t ngram_t2 = get_token(inp, draft, ngram_start_t2);
+        for (int j = ngram_start_t2+1; j < ngram_start_t2 + 2; ++j) {
+            const uint64_t token = get_token(inp, draft, j);
+            ngram_t2 <<= 16;
+            ngram_t2 |= token;
+        }
+        llama_ngram_cache::iterator part_t2_it = nc_t2.find(ngram_t2);
+        llama_ngram_cache_part part_t2;
+        if (part_t2_it != nc_t2.end()) {
+            part_t2 = part_t2_it->second;
+        }
+
+        for (int ngram_size = ngram_max; ngram_size >= ngram_min; --ngram_size) {
+            if (ngram_size > inp_size) {
+                continue;
+            }
+
+            llama_ngram_cache & nc_t1 = ncs_t1[ngram_size - ngram_min];
+
+            const int ngram_start_t1 = inp_size-ngram_size + draft.size()-1;
+            uint64_t ngram_t1 = get_token(inp, draft, ngram_start_t1);
+            for (int j = ngram_start_t1+1; j < ngram_start_t1 + ngram_size; ++j) {
+                const uint64_t token = get_token(inp, draft, j);
+                ngram_t1 <<= 16;
+                ngram_t1 |= token;
+            }
+
+            llama_ngram_cache::iterator part_t1_it = nc_t1.find(ngram_t1);
+            if (part_t1_it == nc_t1.end()) {
+                continue;
+            }
+            const llama_ngram_cache_part part_t1 = part_t1_it->second;
+
+            int max_count_t1 = 0;
+            int max_count_t2 = 0;
+            int sum_count_t1 = 0;
+            llama_token max_token = -1;
+
+            for (std::pair<llama_token, int> token_count_t1 : part_t1) {
+                const llama_token token = token_count_t1.first;
+
+                llama_ngram_cache_part::iterator token_count_t2_it = part_t2.find(token);
+                const int32_t count_t1 = token_count_t1.second;
+                const int32_t count_t2 = token_count_t2_it != part_t2.end() ? 100*token_count_t2_it->second : 1;
+
+                if (count_t1*count_t2 > max_count_t1*max_count_t2) {
+                    max_token = token;
+                    max_count_t1 = count_t1;
+                    max_count_t2 = count_t2;
+                }
+                sum_count_t1 += count_t1;
+            }
+            // Skip this candidate if the sample size is too low:
+            if (sum_count_t1 < draft_min_sample_size[ngram_size-1]) {
+                continue;
+            }
+            // skip this candidate if the empirically most likely token following this token is not likely enough:
+            if (100*max_count_t1 < draft_min_percent[ngram_size-1]*sum_count_t1) {
+                continue;
+            }
+
+            LOG(" - draft candidate: token=%d count=%d\n", max_token, max_count_t1);
+            draft.push_back(max_token);
+            draft_success = true;
+            break;
+        }
+
+        if (!draft_success) {
+            int max_count_t2 = 0;
+            int sum_count_t2 = 0;
+            llama_token max_token = -1;
+
+            for (std::pair<llama_token, int> token_count_t2 : part_t2) {
+                const llama_token token = token_count_t2.first;
+                const int32_t count_t2 = token_count_t2.second;
+
+                if (count_t2 > max_count_t2) {
+                    max_token = token;
+                    max_count_t2 = count_t2;
+                }
+                sum_count_t2 += count_t2;
+            }
+
+            // Skip this candidate if the sample size is too low:
+            if (sum_count_t2 < draft_min_sample_size[2-1]) {
+                break;
+            }
+            // skip this candidate if the empirically most likely token following this token is not likely enough:
+            if (100*max_count_t2 < draft_min_percent[2-1]*sum_count_t2) {
+                break;
+            }
+
+            LOG(" - draft candidate: token=%d count=%d\n", max_token, max_count_t2);
+            draft.push_back(max_token);
+            draft_success = true;
+            break;
+        }
+
+        if (!draft_success) {
+            break;
+        }
+    }
+};
+
+void llama_ngram_cache_save(std::vector<llama_ngram_cache> & ngram_cache, std::string & filename) {
+    GGML_ASSERT(ngram_cache.size() == 1);
+    std::ofstream file_out(filename, std::ios::binary);
+    for (std::pair<uint64_t, llama_ngram_cache_part> item : ngram_cache[0]) {
+        const uint64_t ngram = item.first;
+        llama_ngram_cache_part token_counts = item.second;
+        GGML_ASSERT(!token_counts.empty());
+        const int32_t ntokens = token_counts.size();
+
+        file_out.write(reinterpret_cast<const char *>(&ngram), sizeof(uint64_t));
+        file_out.write(reinterpret_cast<const char *>(&ntokens), sizeof(int32_t));
+        for (std::pair<llama_token, int32_t> item2 : token_counts) {
+            const llama_token token = item2.first;
+            const int32_t count = item2.second;
+            file_out.write(reinterpret_cast<const char *>(&token), sizeof(llama_token));
+            file_out.write(reinterpret_cast<const char *>(&count), sizeof(int32_t));
+        }
+    }
+
+}
+
+llama_ngram_cache llama_ngram_cache_load(std::string & filename) {
+    std::ifstream hashmap_file(filename, std::ios::binary);
+    if (!hashmap_file) {
+        fprintf(stderr, "error: failed to open file '%s'\n", filename.c_str());
+        exit(1);
+    }
+    llama_ngram_cache ngram_cache;
+
+    uint64_t ngram;
+    int32_t ntokens;
+    llama_token token;
+    int32_t count;
+
+    char * ngramc = reinterpret_cast<char*>(&ngram);
+    char * ntokensc = reinterpret_cast<char*>(&ntokens);
+    char * tokenc = reinterpret_cast<char*>(&token);
+    char * countc = reinterpret_cast<char*>(&count);
+    while (hashmap_file.read(ngramc, sizeof(uint64_t))) {
+        GGML_ASSERT(hashmap_file.read(ntokensc, sizeof(int32_t)));
+        llama_ngram_cache_part token_counts;
+
+        for (int i = 0; i < ntokens; ++i) {
+            GGML_ASSERT(hashmap_file.read(tokenc, sizeof(llama_token)));
+            GGML_ASSERT(hashmap_file.read(countc, sizeof(int32_t)));
+            token_counts.emplace(token, count);
+        }
+
+        ngram_cache.emplace(ngram, token_counts);
+    }
+    GGML_ASSERT(hashmap_file.eof());
+
+    return ngram_cache;
+}
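Taken together, the functions above form a small n-gram cache API: llama_ngram_cache_update() builds or extends the hash maps, llama_ngram_cache_draft() proposes draft tokens, and llama_ngram_cache_save()/llama_ngram_cache_load() serialize a single cache as a flat sequence of records (ngram, ntokens, then ntokens pairs of token and count). Below is a minimal usage sketch; the helper name build_and_save_cache and the choice of ngram_min = 2 are illustrative assumptions, not part of the commit.

// Hypothetical sketch: build a single-size n-gram cache over a token sequence
// and write it to disk in the format read back by llama_ngram_cache_load().
#include "common.h"

#include <string>
#include <vector>

static void build_and_save_cache(std::vector<llama_token> & tokens, std::string & path) {
    const int ngram_min = 2;               // with a single cache, ngram_min == ngram_max
    std::vector<llama_ngram_cache> ncs(1); // llama_ngram_cache_save() asserts exactly one cache

    // On the first call every token is "new"; incremental callers would pass only
    // the number of tokens appended since the previous update.
    llama_ngram_cache_update(ncs, ngram_min, tokens, (int) tokens.size(), /*print_progress =*/ false);

    llama_ngram_cache_save(ncs, path);
}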

common/common.h

Lines changed: 39 additions & 11 deletions
@@ -84,17 +84,18 @@ struct gpt_params {
    // // sampling parameters
    struct llama_sampling_params sparams;

-    std::string model = "models/7B/ggml-model-f16.gguf"; // model path
-    std::string model_draft = ""; // draft model for speculative decoding
-    std::string model_alias = "unknown"; // model alias
-    std::string prompt = "";
-    std::string prompt_file = ""; // store the external prompt file name
-    std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state
-    std::string input_prefix = ""; // string to prefix user inputs with
-    std::string input_suffix = ""; // string to suffix user inputs with
-    std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
-    std::string logdir = ""; // directory in which to save YAML log files
-    std::string logits_file = ""; // file for saving *all* logits
+    std::string model = "models/7B/ggml-model-f16.gguf"; // model path
+    std::string model_draft = ""; // draft model for speculative decoding
+    std::string model_alias = "unknown"; // model alias
+    std::string prompt = "";
+    std::string prompt_file = ""; // store the external prompt file name
+    std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state
+    std::string input_prefix = ""; // string to prefix user inputs with
+    std::string input_suffix = ""; // string to suffix user inputs with
+    std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
+    std::string logdir = ""; // directory in which to save YAML log files
+    std::string lookup_cache_static = ""; // path of ngram cache file for lookup decoding
+    std::string logits_file = ""; // file for saving *all* logits

    std::vector<llama_model_kv_override> kv_overrides;

@@ -260,3 +261,30 @@ void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size = 80);

// Dump the KV cache view showing individual sequences in each cell (long output).
void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size = 40);
+
+#define LLAMA_NGRAM_MAX 4
+
+// Data structures to map n-grams to empirical token probabilities:
+typedef std::unordered_map<llama_token, int32_t> llama_ngram_cache_part; // token -> number of times token has been seen
+typedef std::unordered_map<uint64_t, llama_ngram_cache_part> llama_ngram_cache; // n-gram -> empirical distribution of following tokens
+// n-grams are encoded as 64 bit integers with each of the 4 16 bit sections representing a token id.
+// This way no custom hashing function for the n-grams is needed.
+
+// Update an ngram cache with tokens.
+// ncs = ngram caches: the hashmaps to modify.
+// ngram_min/ngram_max: the min/max size of the ngrams in ncs.
+// inp_data: the token sequence on which the hashmaps are based.
+// nnew: how many new tokens have been appended to inp_data since the last call to this function.
+// print_progress: whether to print progress to stderr
+//
+// In order to get correct results inp_data can ONLY BE APPENDED TO.
+// Changes in the middle need a complete rebuild.
+void llama_ngram_cache_update(std::vector<llama_ngram_cache> & ncs, int ngram_min,
+                              std::vector<llama_token> & inp_data, int nnew, bool print_progress);
+
+void llama_ngram_cache_draft(
+    std::vector<llama_token> & inp, std::vector<llama_token> & draft, int n_draft,
+    std::vector<llama_ngram_cache> & ncs_t1, int ngram_min, llama_ngram_cache & nc_t2);
+
+void llama_ngram_cache_save(std::vector<llama_ngram_cache> & ngram_cache, std::string & filename);
+llama_ngram_cache llama_ngram_cache_load(std::string & filename);
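To make the 64-bit key encoding described above concrete, here is a small, hypothetical helper (pack_ngram is not part of the commit) that mirrors the shift-and-OR packing done inline in common.cpp, under the assumption that token ids fit into 16 bits. For example, the 3-gram of token ids {1, 2, 3} packs to 0x0000000100020003.

#include <cstdint>
#include <vector>

// Mirrors the packing in llama_ngram_cache_update()/llama_ngram_cache_draft():
// the first token ends up in the highest-shifted 16-bit section, the last token
// in the lowest 16 bits. Token ids >= 2^16 would spill into the neighbouring
// section, which appears to be what the FIXME in common.cpp refers to.
static uint64_t pack_ngram(const std::vector<int32_t> & tokens, // int32_t stands in for llama_token
                           int start, int ngram_size) {
    uint64_t ngram = (uint64_t) tokens[start];
    for (int j = start + 1; j < start + ngram_size; ++j) {
        ngram <<= 16;
        ngram |= (uint64_t) tokens[j];
    }
    return ngram;
}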
