Skip to content

Commit 69a16ef

Browse files
refactor, add comments
1 parent 1ed6213 commit 69a16ef

File tree

7 files changed

+141
-121
lines changed

7 files changed

+141
-121
lines changed

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -659,7 +659,7 @@ libllama.a: llama.o ggml.o $(OBJS) $(COMMON_DEPS)
659659
ar rcs libllama.a llama.o ggml.o $(OBJS) $(COMMON_DEPS)
660660

661661
clean:
662-
rm -vrf *.o tests/*.o *.so *.a *.dll benchmark-matmult common/build-info.cpp *.dot $(COV_TARGETS) $(BUILD_TARGETS) $(TEST_TARGETS)
662+
rm -vrf *.o tests/*.o *.so *.a *.dll benchmark-matmult lookup-create lookup-merge lookup-stats common/build-info.cpp *.dot $(COV_TARGETS) $(BUILD_TARGETS) $(TEST_TARGETS)
663663
find examples pocs -type f -name "*.o" -delete
664664

665665
#

common/common.cpp

Lines changed: 75 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -1879,9 +1879,9 @@ void llama_ngram_cache_update(llama_ngram_cache & ngram_cache, int ngram_min, in
18791879
const int i_start = std::max(inp_size - nnew, ngram_size);
18801880
for (int i = i_start; i < inp_size; ++i) {
18811881
const int ngram_start = i - ngram_size;
1882-
uint64_t ngram = inp[ngram_start];
1882+
llama_ngram ngram = inp[ngram_start];
18831883
for (int j = ngram_start+1; j < ngram_start + ngram_size; ++j) { // FIXME
1884-
const uint64_t ngram_part = inp[j];
1884+
const llama_ngram ngram_part = inp[j];
18851885
ngram <<= 16;
18861886
ngram |= ngram_part;
18871887
}
@@ -1904,7 +1904,7 @@ void llama_ngram_cache_update(llama_ngram_cache & ngram_cache, int ngram_min, in
19041904
const int64_t t_now_ms = ggml_time_ms();
19051905
const int64_t eta_ms = (inp_size - i) * (t_now_ms - t_start_ms) / i;
19061906
const int64_t eta_min = eta_ms / (60*1000);
1907-
const int64_t eta_s = (eta_ms - eta_min) / 1000;
1907+
const int64_t eta_s = (eta_ms - 60*1000*eta_min) / 1000;
19081908

19091909
fprintf(stderr, "%s: %d/%d done, ETA: %02ld:%02ld\n", __func__, i, inp_size, eta_min, eta_s);
19101910
}
@@ -1917,75 +1917,77 @@ static llama_token get_token(const std::vector<llama_token> & inp, const std::ve
19171917
return i < inp.size() ? inp[i] : draft[1 + i - inp.size()];
19181918
};
19191919

1920-
// If sample size or percentage in context are below these thresholds the draft is aborted early:
1921-
constexpr int draft_min_sample_size_t1[LLAMA_NGRAM_MAX] = { 2, 2, 1, 1};
1922-
constexpr int draft_min_percent_t1[LLAMA_NGRAM_MAX] = {66, 50, 50, 50};
1923-
constexpr int draft_min_sample_size_t2[LLAMA_NGRAM_MAX] = { 4, 3, 2, 2};
1924-
constexpr int draft_min_percent_t2[LLAMA_NGRAM_MAX] = {75, 66, 66, 66};
1920+
// If the sample size or percentage is below these thresholds the draft is aborted early:
1921+
constexpr int draft_min_sample_size_lax[LLAMA_NGRAM_MAX] = { 2, 2, 1, 1};
1922+
constexpr int draft_min_percent_lax[LLAMA_NGRAM_MAX] = {66, 50, 50, 50};
1923+
constexpr int draft_min_sample_size_strict[LLAMA_NGRAM_MAX] = { 4, 3, 2, 2};
1924+
constexpr int draft_min_percent_strict[LLAMA_NGRAM_MAX] = {75, 66, 66, 66};
19251925

1926-
static llama_token try_draft(llama_ngram_cache & nc_primary, const uint64_t ngram_primary) {
1927-
llama_ngram_cache::iterator part_primary_it = nc_primary.find(ngram_primary);
1928-
if (part_primary_it == nc_primary.end()) {
1926+
// Helper function that tries to draft a token from only the static ngram cache:
1927+
static llama_token try_draft(llama_ngram_cache & nc_static, const llama_ngram ngram_static) {
1928+
llama_ngram_cache::iterator part_static_it = nc_static.find(ngram_static);
1929+
if (part_static_it == nc_static.end()) {
19291930
return -1;
19301931
}
1931-
const llama_ngram_cache_part part_primary = part_primary_it->second;
1932+
const llama_ngram_cache_part part_static = part_static_it->second;
19321933

1933-
int max_count_primary = 0;
1934-
int sum_count_primary = 0;
1935-
llama_token max_token = -1;
1934+
int max_count_static = 0;
1935+
int sum_count_static = 0;
1936+
llama_token max_token = -1;
19361937

1937-
for (std::pair<llama_token, int> token_count_primary : part_primary) {
1938-
const llama_token token = token_count_primary.first;
1939-
const int32_t count_primary = token_count_primary.second;
1938+
for (std::pair<llama_token, int> token_count_static : part_static) {
1939+
const llama_token token = token_count_static.first;
1940+
const int32_t count_static = token_count_static.second;
19401941

1941-
if (count_primary > max_count_primary) {
1942-
max_token = token;
1943-
max_count_primary = count_primary;
1942+
if (count_static > max_count_static) {
1943+
max_token = token;
1944+
max_count_static = count_static;
19441945
}
1945-
sum_count_primary += count_primary;
1946+
sum_count_static += count_static;
19461947
}
19471948

1948-
if (sum_count_primary < draft_min_sample_size_t1[2-1]) {
1949+
if (sum_count_static < draft_min_sample_size_lax[LLAMA_NGRAM_STATIC-1]) {
19491950
return -1;
19501951
}
1951-
if (100*max_count_primary < draft_min_percent_t1[2-1]*sum_count_primary) {
1952+
if (100*max_count_static < draft_min_percent_lax[LLAMA_NGRAM_STATIC-1]*sum_count_static) {
19521953
return -1;
19531954
}
19541955
return max_token;
19551956
}
19561957

1958+
// Try to draft a token from primary cache (context/dynamic), validate with static cache:
19571959
static llama_token try_draft(
1958-
llama_ngram_cache & nc_primary, const std::vector<uint64_t> & ngrams_primary, llama_ngram_cache_part & part_validate,
1960+
llama_ngram_cache & nc_primary, const std::vector<llama_ngram> & ngrams_primary, llama_ngram_cache_part & part_static,
19591961
const int * min_sample_size, const int * min_percent) {
19601962

19611963
llama_token drafted_token = -1;
19621964

19631965
for (int i = ngrams_primary.size()-1; i >= 0 && drafted_token == -1; --i) {
1964-
const uint64_t ngram_primary = ngrams_primary[i];
1966+
const llama_ngram ngram_primary = ngrams_primary[i];
19651967

19661968
llama_ngram_cache::iterator part_primary_it = nc_primary.find(ngram_primary);
19671969
if (part_primary_it == nc_primary.end()) {
19681970
continue;
19691971
}
19701972
const llama_ngram_cache_part part_primary = part_primary_it->second;
19711973

1972-
int max_count_primary = 0;
1973-
int max_count_validate = 0;
1974-
int sum_count_primary = 0;
1975-
llama_token max_token = -1;
1974+
int max_count_primary = 0;
1975+
int max_count_static = 0;
1976+
int sum_count_primary = 0;
1977+
llama_token max_token = -1;
19761978

19771979
for (std::pair<llama_token, int> token_count_primary : part_primary) {
19781980
const llama_token token = token_count_primary.first;
19791981

1980-
llama_ngram_cache_part::iterator token_count_validate_it = part_validate.find(token);
1982+
llama_ngram_cache_part::iterator token_count_static_it = part_static.find(token);
19811983

1982-
const int32_t count_primary = token_count_primary.second;
1983-
const int32_t count_validate = token_count_validate_it != part_validate.end() ? 100*token_count_validate_it->second : 1;
1984+
const int32_t count_primary = token_count_primary.second;
1985+
const int32_t count_static = token_count_static_it != part_static.end() ? 100*token_count_static_it->second : 1;
19841986

1985-
if (count_primary*count_validate > max_count_primary*max_count_validate) {
1986-
max_token = token;
1987-
max_count_primary = count_primary;
1988-
max_count_validate = count_validate;
1987+
if (count_primary*count_static > max_count_primary*max_count_static) {
1988+
max_token = token;
1989+
max_count_primary = count_primary;
1990+
max_count_static = count_static;
19891991
}
19901992
sum_count_primary += count_primary;
19911993
}
@@ -2004,49 +2006,51 @@ static llama_token try_draft(
20042006

20052007
void llama_ngram_cache_draft(
20062008
std::vector<llama_token> & inp, std::vector<llama_token> & draft, int n_draft, int ngram_min, int ngram_max,
2007-
llama_ngram_cache & nc_t1, llama_ngram_cache & nc_t2, llama_ngram_cache & nc_t3
2009+
llama_ngram_cache & nc_context, llama_ngram_cache & nc_dynamic, llama_ngram_cache & nc_static
20082010
) {
2011+
GGML_ASSERT(draft.size() == 1);
20092012
const int inp_size = inp.size();
20102013

2011-
if (inp_size < 2) {
2014+
if (inp_size < LLAMA_NGRAM_STATIC) {
20122015
return;
20132016
}
20142017

20152018
while ((int) draft.size()-1 < n_draft) {
20162019
llama_token drafted_token = -1;
20172020

2018-
const int ngram_start_t23 = inp_size-2 + draft.size()-1;
2019-
uint64_t ngram_t23 = get_token(inp, draft, ngram_start_t23);
2020-
for (int j = ngram_start_t23+1; j < ngram_start_t23 + 2; ++j) {
2021-
const uint64_t token = get_token(inp, draft, j);
2022-
ngram_t23 <<= 16;
2023-
ngram_t23 |= token;
2021+
const int ngram_start_static = inp_size-LLAMA_NGRAM_STATIC + draft.size()-1;
2022+
llama_ngram ngram_static = get_token(inp, draft, ngram_start_static);
2023+
for (int j = ngram_start_static+1; j < ngram_start_static + LLAMA_NGRAM_STATIC; ++j) {
2024+
const llama_ngram token = get_token(inp, draft, j);
2025+
ngram_static <<= 16;
2026+
ngram_static |= token;
20242027
}
2025-
llama_ngram_cache::iterator part_t3_it = nc_t3.find(ngram_t23);
2026-
llama_ngram_cache_part part_t3;
2027-
if (part_t3_it != nc_t3.end()) {
2028-
part_t3 = part_t3_it->second;
2028+
llama_ngram_cache::iterator part_static_it = nc_static.find(ngram_static);
2029+
llama_ngram_cache_part part_static;
2030+
if (part_static_it != nc_static.end()) {
2031+
part_static = part_static_it->second;
20292032
}
20302033

2031-
std::vector<uint64_t> ngrams_t12;
2032-
for (int ngram_size = ngram_min; ngram_size <= ngram_max; ++ngram_size) {
2033-
const int ngram_start_t12 = inp_size-ngram_size + draft.size()-1;
2034-
uint64_t ngram_t12 = get_token(inp, draft, ngram_start_t12);
2035-
for (int j = ngram_start_t12+1; j < ngram_start_t12 + ngram_size; ++j) {
2036-
const uint64_t token = get_token(inp, draft, j);
2037-
ngram_t12 <<= 16;
2038-
ngram_t12 |= token;
2039-
}
2040-
ngrams_t12.push_back(ngram_t12);
2034+
// cd = context + dynamic
2035+
std::vector<llama_ngram> ngrams_cd;
2036+
for (int ngram_size_cd = ngram_min; ngram_size_cd <= ngram_max; ++ngram_size_cd) {
2037+
const int ngram_start_cd = inp_size-ngram_size_cd + draft.size()-1;
2038+
llama_ngram ngram_cd = get_token(inp, draft, ngram_start_cd);
2039+
for (int j = ngram_start_cd+1; j < ngram_start_cd + ngram_size_cd; ++j) {
2040+
const llama_ngram token = get_token(inp, draft, j);
2041+
ngram_cd <<= 16;
2042+
ngram_cd |= token;
2043+
}
2044+
ngrams_cd.push_back(ngram_cd);
20412045
}
20422046
if (drafted_token == -1) {
2043-
drafted_token = try_draft(nc_t1, ngrams_t12, part_t3, draft_min_sample_size_t1, draft_min_percent_t1);
2047+
drafted_token = try_draft(nc_context, ngrams_cd, part_static, draft_min_sample_size_lax, draft_min_percent_lax);
20442048
}
20452049
if (drafted_token == -1) {
2046-
drafted_token = try_draft(nc_t2, ngrams_t12, part_t3, draft_min_sample_size_t2, draft_min_percent_t2);
2050+
drafted_token = try_draft(nc_dynamic, ngrams_cd, part_static, draft_min_sample_size_strict, draft_min_percent_strict);
20472051
}
20482052
if (drafted_token == -1) {
2049-
drafted_token = try_draft(nc_t3, ngram_t23);
2053+
drafted_token = try_draft(nc_static, ngram_static);
20502054
}
20512055

20522056
if (drafted_token == -1) {
@@ -2060,14 +2064,13 @@ void llama_ngram_cache_draft(
20602064

20612065
void llama_ngram_cache_save(llama_ngram_cache & ngram_cache, std::string & filename) {
20622066
std::ofstream file_out(filename, std::ios::binary);
2063-
for (std::pair<uint64_t, llama_ngram_cache_part> item : ngram_cache) {
2064-
const uint64_t ngram = item.first;
2067+
for (std::pair<llama_ngram, llama_ngram_cache_part> item : ngram_cache) {
2068+
const llama_ngram ngram = item.first;
20652069
llama_ngram_cache_part token_counts = item.second;
20662070
GGML_ASSERT(!token_counts.empty());
20672071
const int32_t ntokens = token_counts.size();
20682072

2069-
2070-
file_out.write(reinterpret_cast<const char *>(&ngram), sizeof(uint64_t));
2073+
file_out.write(reinterpret_cast<const char *>(&ngram), sizeof(llama_ngram));
20712074
file_out.write(reinterpret_cast<const char *>(&ntokens), sizeof(int32_t));
20722075
for (std::pair<llama_token, int32_t> item2 : token_counts) {
20732076
const llama_token token = item2.first;
@@ -2086,16 +2089,16 @@ llama_ngram_cache llama_ngram_cache_load(std::string & filename) {
20862089
}
20872090
llama_ngram_cache ngram_cache;
20882091

2089-
uint64_t ngram;
2090-
int32_t ntokens;
2092+
llama_ngram ngram;
2093+
int32_t ntokens;
20912094
llama_token token;
2092-
int32_t count;
2095+
int32_t count;
20932096

20942097
char * ngramc = reinterpret_cast<char*>(&ngram);
20952098
char * ntokensc = reinterpret_cast<char*>(&ntokens);
20962099
char * tokenc = reinterpret_cast<char*>(&token);
20972100
char * countc = reinterpret_cast<char*>(&count);
2098-
while(hashmap_file.read(ngramc, sizeof(uint64_t))) {
2101+
while(hashmap_file.read(ngramc, sizeof(llama_ngram))) {
20992102
GGML_ASSERT(hashmap_file.read(ntokensc, sizeof(int32_t)));
21002103
llama_ngram_cache_part token_counts;
21012104

@@ -2113,8 +2116,8 @@ llama_ngram_cache llama_ngram_cache_load(std::string & filename) {
21132116
}
21142117

21152118
void llama_ngram_cache_merge(llama_ngram_cache & ngram_cache_target, llama_ngram_cache & ngram_cache_add) {
2116-
for (std::pair<uint64_t, llama_ngram_cache_part> ngram_part : ngram_cache_add) {
2117-
const uint64_t ngram = ngram_part.first;
2119+
for (std::pair<llama_ngram, llama_ngram_cache_part> ngram_part : ngram_cache_add) {
2120+
const llama_ngram ngram = ngram_part.first;
21182121
llama_ngram_cache_part part = ngram_part.second;
21192122

21202123
llama_ngram_cache::iterator part_merged_it = ngram_cache_target.find(ngram);

common/common.h

Lines changed: 36 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -263,30 +263,52 @@ void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size = 80);
263263
// Dump the KV cache view showing individual sequences in each cell (long output).
264264
void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size = 40);
265265

266-
#define LLAMA_NGRAM_MAX 4
266+
#define LLAMA_NGRAM_MIN 1
267+
#define LLAMA_NGRAM_MAX 4
268+
#define LLAMA_NGRAM_STATIC 2
267269

268270
// Data structures to map n-grams to empirical token probabilities:
269-
typedef std::unordered_map<llama_token, int32_t> llama_ngram_cache_part; // token -> number of times token has been seen
270-
typedef std::unordered_map<uint64_t, llama_ngram_cache_part> llama_ngram_cache; // n-gram -> empirical distribution of following tokens
271-
// n-grams are encoded as 64 bit integers with each of the 4 16 bit sections representing a token id.
272-
// This way no custom hashing function for the n-grams is needed.
271+
typedef uint64_t llama_ngram; // Each of the 4 16 bit sections represents a token id.
272+
typedef std::unordered_map<llama_token, int32_t> llama_ngram_cache_part; // token -> number of times token has been seen
273+
typedef std::unordered_map<llama_ngram, llama_ngram_cache_part> llama_ngram_cache; // n-gram -> empirical distribution of following tokens
274+
275+
static_assert(LLAMA_NGRAM_MAX <= sizeof(llama_ngram)/2, "A 64 bit integer can only hold information for 4 16 bit tokens.");
273276

274277
// Update an ngram cache with tokens.
275-
// ncs = ngram caches: the hashmaps to modify.
276-
// ngram_min/ngram_max: the min/max size of the ngrams in ncs.
277-
// inp_data: the token sequence on which the hashmaps are based.
278-
// nnew: how many new tokens have been appended to inp_data since the last call to this function.
279-
// print_progress: whether to print progress to stderr
278+
// ngram_cache: the cache to modify.
279+
// ngram_min/ngram_max: the min/max size of the ngrams to extract from inp_data.
280+
// inp_data: the token sequence with which to update ngram_cache.
281+
// nnew: how many new tokens have been appended to inp_data since the last call to this function.
282+
// print_progress: whether to print progress to stderr.
280283
//
281284
// In order to get correct results inp_data can ONLY BE APPENDED TO.
282285
// Changes in the middle need a complete rebuild.
283-
void llama_ngram_cache_update(llama_ngram_cache & ngram_cache, int ngram_min, int ngram_max,
284-
std::vector<llama_token> & inp_data, int nnew, bool print_progress);
285-
286+
void llama_ngram_cache_update(
287+
llama_ngram_cache & ngram_cache, int ngram_min, int ngram_max, std::vector<llama_token> & inp_data, int nnew, bool print_progress);
288+
289+
// Try to draft tokens from ngram caches.
290+
// inp: the tokens generated so far.
291+
// draft: the token sequence to draft. Expected to initially contain the previously sampled token.
292+
// n_draft: maximum number of tokens to add to draft.
293+
// ngram_min/ngram_max: the min/max size of the ngrams in nc_context and nc_dynamic.
294+
// nc_context: ngram cache based on current context.
295+
// nc_dynamic: ngram cache based on previous user generations.
296+
// nc_static: ngram cache generated from a large text corpus, used for validation.
286297
void llama_ngram_cache_draft(
287298
std::vector<llama_token> & inp, std::vector<llama_token> & draft, int n_draft, int ngram_min, int ngram_max,
288-
llama_ngram_cache & nc_t1, llama_ngram_cache & nc_t2, llama_ngram_cache & nc_t3);
299+
llama_ngram_cache & nc_context, llama_ngram_cache & nc_dynamic, llama_ngram_cache & nc_static);
289300

301+
// Save an ngram cache to a file.
302+
// ngram_cache: the ngram cache to save.
303+
// filename: the path under which to save the ngram cache.
290304
void llama_ngram_cache_save(llama_ngram_cache & ngram_cache, std::string & filename);
305+
306+
// Load an ngram cache saved with llama_ngram_cache_save.
307+
// filename: the path from which to load the ngram cache.
308+
// returns: an ngram cache containing the information saved to filename.
291309
llama_ngram_cache llama_ngram_cache_load(std::string & filename);
310+
311+
// Merge two ngram caches.
312+
// ngram_cache_target: the ngram cache to which to add the information from ngram_cache_add.
313+
// ngram_cache_add: the ngram cache to add to ngram_cache_target.
292314
void llama_ngram_cache_merge(llama_ngram_cache & ngram_cache_target, llama_ngram_cache & ngram_cache_add);

examples/lookup/lookup-create.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@
99
#include <unordered_map>
1010
#include <vector>
1111

12-
constexpr int ngram_size = 2;
13-
1412
int main(int argc, char ** argv){
1513
gpt_params params;
1614

@@ -37,8 +35,8 @@ int main(int argc, char ** argv){
3735

3836

3937
llama_ngram_cache ngram_cache;
40-
llama_ngram_cache_update(ngram_cache, ngram_size, ngram_size, inp, inp.size(), true);
41-
fprintf(stderr, "%s: hashing done, writing file\n", __func__);
38+
llama_ngram_cache_update(ngram_cache, LLAMA_NGRAM_STATIC, LLAMA_NGRAM_STATIC, inp, inp.size(), true);
39+
fprintf(stderr, "%s: hashing done, writing file to %s\n", __func__, params.lookup_cache_static.c_str());
4240

4341
llama_ngram_cache_save(ngram_cache, params.lookup_cache_static);
4442
}

examples/lookup/lookup-merge.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ int main(int argc, char ** argv){
3232
}
3333
}
3434

35+
fprintf(stderr, "lookup-merge: loading file %s\n", args[0].c_str());
3536
llama_ngram_cache ngram_cache_merged = llama_ngram_cache_load(args[0]);
3637

3738
for (size_t i = 1; i < args.size()-1; ++i) {

0 commit comments

Comments
 (0)