
Commit e7d1e38

lookup: evaluation tools, use corpus/previous gens
1 parent f9c7ba3 commit e7d1e38

13 files changed: +758 additions, -61 deletions

.gitignore

Lines changed: 3 additions & 0 deletions
@@ -57,6 +57,9 @@ models-mnt
 /llava-cli
 /lookahead
 /lookup
+/lookup-create
+/lookup-merge
+/lookup-stats
 /main
 /metal
 /passkey

Makefile

Lines changed: 11 additions & 2 deletions
@@ -669,14 +669,17 @@ grammar-parser.o: common/grammar-parser.cpp common/grammar-parser.h
 train.o: common/train.cpp common/train.h
 	$(CXX) $(CXXFLAGS) -c $< -o $@
 
+ngram-cache.o: common/ngram-cache.cpp common/ngram-cache.h
+	$(CXX) $(CXXFLAGS) -c $< -o $@
+
 libllama.so: llama.o ggml.o $(OBJS)
 	$(CXX) $(CXXFLAGS) -shared -fPIC -o $@ $^ $(LDFLAGS)
 
 libllama.a: llama.o ggml.o $(OBJS) $(COMMON_DEPS)
 	ar rcs libllama.a llama.o ggml.o $(OBJS) $(COMMON_DEPS)
 
 clean:
-	rm -vrf *.o tests/*.o *.so *.a *.dll benchmark-matmult common/build-info.cpp *.dot $(COV_TARGETS) $(BUILD_TARGETS) $(TEST_TARGETS)
+	rm -vrf *.o tests/*.o *.so *.a *.dll benchmark-matmult lookup-create lookup-merge lookup-stats common/build-info.cpp *.dot $(COV_TARGETS) $(BUILD_TARGETS) $(TEST_TARGETS)
 	find examples pocs -type f -name "*.o" -delete
 
 #
@@ -806,9 +809,15 @@ lookahead: examples/lookahead/lookahead.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS
 	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
 
-lookup: examples/lookup/lookup.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
+lookup: examples/lookup/lookup.cpp ggml.o llama.o ngram-cache.o $(COMMON_DEPS) $(OBJS)
 	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
+	$(CXX) $(CXXFLAGS) -c examples/lookup/lookup-create.cpp -o $(call GET_OBJ_FILE, examples/lookup/lookup-create.cpp)
+	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, examples/lookup/lookup-create.cpp) -o lookup-create $(LDFLAGS)
+	$(CXX) $(CXXFLAGS) -c examples/lookup/lookup-merge.cpp -o $(call GET_OBJ_FILE, examples/lookup/lookup-merge.cpp)
+	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, examples/lookup/lookup-merge.cpp) -o lookup-merge $(LDFLAGS)
+	$(CXX) $(CXXFLAGS) -c examples/lookup/lookup-stats.cpp -o $(call GET_OBJ_FILE, examples/lookup/lookup-stats.cpp)
+	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, examples/lookup/lookup-stats.cpp) -o lookup-stats $(LDFLAGS)
 
 passkey: examples/passkey/passkey.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
 	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)

common/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -62,6 +62,8 @@ add_library(${TARGET} STATIC
     grammar-parser.cpp
     train.h
     train.cpp
+    ngram-cache.h
+    ngram-cache.cpp
     )
 
 if (BUILD_SHARED_LIBS)

common/common.cpp

Lines changed: 20 additions & 0 deletions
@@ -948,6 +948,22 @@ static bool gpt_params_find_arg(int argc, char ** argv, gpt_params & params, int
         }
         return true;
     }
+    if (arg == "-lcs" || arg == "--lookup-cache-static") {
+        if (++i >= argc) {
+            invalid_param = true;
+            return true;
+        }
+        params.lookup_cache_static = argv[i];
+        return true;
+    }
+    if (arg == "-lcd" || arg == "--lookup-cache-dynamic") {
+        if (++i >= argc) {
+            invalid_param = true;
+            return true;
+        }
+        params.lookup_cache_dynamic = argv[i];
+        return true;
+    }
     if (arg == "--save-all-logits" || arg == "--kl-divergence-base") {
         if (++i >= argc) {
             invalid_param = true;
@@ -1410,6 +1426,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf(" draft model for speculative decoding\n");
     printf(" -ld LOGDIR, --logdir LOGDIR\n");
     printf(" path under which to save YAML logs (no logging if unset)\n");
+    printf(" -lcs FNAME, --lookup-cache-static FNAME\n");
+    printf(" path to static lookup cache to use for lookup decoding (not updated by generation)\n");
+    printf(" -lcd FNAME, --lookup-cache-dynamic FNAME\n");
+    printf(" path to dynamic lookup cache to use for lookup decoding (updated by generation)\n");
     printf(" --override-kv KEY=TYPE:VALUE\n");
     printf(" advanced option to override model metadata by key. may be specified multiple times.\n");
     printf(" types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n");

common/common.h

Lines changed: 13 additions & 11 deletions
@@ -88,18 +88,20 @@ struct gpt_params {
     // // sampling parameters
     struct llama_sampling_params sparams;
 
-    std::string model             = "models/7B/ggml-model-f16.gguf"; // model path
-    std::string model_url         = ""; // model url to download
-    std::string model_draft       = ""; // draft model for speculative decoding
-    std::string model_alias       = "unknown"; // model alias
-    std::string prompt            = "";
-    std::string prompt_file       = ""; // store the external prompt file name
-    std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state
-    std::string input_prefix      = ""; // string to prefix user inputs with
-    std::string input_suffix      = ""; // string to suffix user inputs with
+    std::string model                = "models/7B/ggml-model-f16.gguf"; // model path
+    std::string model_url            = ""; // model url to download
+    std::string model_draft          = ""; // draft model for speculative decoding
+    std::string model_alias          = "unknown"; // model alias
+    std::string prompt               = "";
+    std::string prompt_file          = ""; // store the external prompt file name
+    std::string path_prompt_cache    = ""; // path to file for saving/loading prompt eval state
+    std::string input_prefix         = ""; // string to prefix user inputs with
+    std::string input_suffix         = ""; // string to suffix user inputs with
     std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
-    std::string logdir      = ""; // directory in which to save YAML log files
-    std::string logits_file = ""; // file for saving *all* logits
+    std::string logdir               = ""; // directory in which to save YAML log files
+    std::string lookup_cache_static  = ""; // path of static ngram cache file for lookup decoding
+    std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding
+    std::string logits_file          = ""; // file for saving *all* logits
 
     std::vector<llama_model_kv_override> kv_overrides;
 
common/ngram-cache.cpp

Lines changed: 266 additions & 0 deletions
@@ -0,0 +1,266 @@
+#include "ngram-cache.h"
+#include "log.h"
+
+#include <fstream>
+
+void llama_ngram_cache_update(llama_ngram_cache & ngram_cache, int ngram_min, int ngram_max,
+                              std::vector<llama_token> & inp, int nnew, bool print_progress) {
+    const int64_t t_start_ms = ggml_time_ms();
+    const int inp_size = inp.size();
+
+    for (int ngram_size = ngram_min; ngram_size <= ngram_max; ++ngram_size) {
+        const int i_start = std::max(inp_size - nnew, ngram_size);
+        for (int i = i_start; i < inp_size; ++i) {
+            const int ngram_start = i - ngram_size;
+            llama_ngram ngram(&inp[ngram_start], ngram_size);
+            const llama_token token = inp[i];
+
+            llama_ngram_cache::iterator part_it = ngram_cache.find(ngram);
+            if (part_it == ngram_cache.end()) {
+                llama_ngram_cache_part part;
+                part.emplace(token, 1);
+                ngram_cache.emplace(ngram, part);
+            } else {
+                llama_ngram_cache_part::iterator token_count_it = part_it->second.find(token);
+                if (token_count_it == part_it->second.end()) {
+                    part_it->second.emplace(token, 1);
+                } else {
+                    token_count_it->second++;
+                }
+            }
+            if (print_progress && i % 10000000 == 0) {
+                const int64_t t_now_ms = ggml_time_ms();
+                const int64_t eta_ms   = (inp_size - i) * (t_now_ms - t_start_ms) / i;
+                const int64_t eta_min  = eta_ms / (60*1000);
+                const int64_t eta_s    = (eta_ms - 60*1000*eta_min) / 1000;
+
+                fprintf(stderr, "%s: %d/%d done, ETA: %02ld:%02ld\n", __func__, i, inp_size, eta_min, eta_s);
+            }
+        }
+    }
+}
+
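A short sketch of how this update routine might be driven to build a cache from an entire tokenized corpus; the 1..4 n-gram range, the helper name, and the include path are assumptions for illustration, not part of the commit:

    #include "ngram-cache.h"

    #include <vector>

    // Hypothetical helper (not from this commit): fold a whole tokenized corpus into a fresh cache.
    // Passing nnew == inp.size() means every token counts as new, so all n-grams are recorded.
    static llama_ngram_cache build_corpus_cache(std::vector<llama_token> & corpus_tokens) {
        llama_ngram_cache cache;
        llama_ngram_cache_update(
            cache,
            /*ngram_min=*/1, /*ngram_max=*/4,        // example n-gram sizes, not mandated by the API
            corpus_tokens,
            /*nnew=*/(int) corpus_tokens.size(),
            /*print_progress=*/true);
        return cache;
    }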
+// Helper function to get a token from the combined, speculative sequence of inp and draft.
+static llama_token get_token(const std::vector<llama_token> & inp, const std::vector<llama_token> & draft, const size_t i) {
+    return i < inp.size() ? inp[i] : draft[1 + i - inp.size()];
+};
+
+// If sample size or percentage are below these thresholds the draft is aborted early:
+constexpr int draft_min_sample_size_lax[LLAMA_NGRAM_MAX] = { 2,  2,  1,  1};
+constexpr int draft_min_percent_lax[LLAMA_NGRAM_MAX]     = {66, 50, 50, 50};
+constexpr int draft_min_sample_size_strict[LLAMA_NGRAM_MAX] = { 4,  3,  2,  2};
+constexpr int draft_min_percent_strict[LLAMA_NGRAM_MAX]     = {75, 66, 66, 66};
+
+// Helper function that tries to draft a token from only the static ngram cache:
+static llama_token try_draft(llama_ngram_cache & nc_static, const llama_ngram ngram_static) {
+    llama_ngram_cache::iterator part_static_it = nc_static.find(ngram_static);
+    if (part_static_it == nc_static.end()) {
+        return -1;
+    }
+    const llama_ngram_cache_part part_static = part_static_it->second;
+
+    int max_count_static  = 0;
+    int sum_count_static  = 0;
+    llama_token max_token = -1;
+
+    for (std::pair<llama_token, int> token_count_static : part_static) {
+        const llama_token token        = token_count_static.first;
+        const int32_t     count_static = token_count_static.second;
+
+        if (count_static > max_count_static) {
+            max_token        = token;
+            max_count_static = count_static;
+        }
+        sum_count_static += count_static;
+    }
+
+    if (sum_count_static < draft_min_sample_size_lax[LLAMA_NGRAM_STATIC-1]) {
+        return -1;
+    }
+    if (100*max_count_static < draft_min_percent_lax[LLAMA_NGRAM_STATIC-1]*sum_count_static) {
+        return -1;
+    }
+    return max_token;
+}
+
+// Try to draft a token from primary cache (context/dynamic), validate with static cache:
+static llama_token try_draft(
+    llama_ngram_cache & nc_primary, const std::vector<llama_ngram> & ngrams_primary, llama_ngram_cache_part & part_static,
+    const int * min_sample_size, const int * min_percent) {
+
+    llama_token drafted_token = -1;
+
+    for (int i = ngrams_primary.size()-1; i >= 0 && drafted_token == -1; --i) {
+        const llama_ngram ngram_primary = ngrams_primary[i];
+
+        llama_ngram_cache::iterator part_primary_it = nc_primary.find(ngram_primary);
+        if (part_primary_it == nc_primary.end()) {
+            continue;
+        }
+        const llama_ngram_cache_part part_primary = part_primary_it->second;
+
+        int max_count_primary = 0;
+        int max_count_static  = 0;
+        int sum_count_primary = 0;
+        llama_token max_token = -1;
+
+        for (std::pair<llama_token, int> token_count_primary : part_primary) {
+            const llama_token token = token_count_primary.first;
+
+            llama_ngram_cache_part::iterator token_count_static_it = part_static.find(token);
+
+            const int32_t count_primary = token_count_primary.second;
+            const int32_t count_static  = token_count_static_it != part_static.end() ? 100*token_count_static_it->second : 1;
+
+            if (count_primary*count_static > max_count_primary*max_count_static) {
+                max_token         = token;
+                max_count_primary = count_primary;
+                max_count_static  = count_static;
+            }
+            sum_count_primary += count_primary;
+        }
+
+        if (sum_count_primary < min_sample_size[i]) {
+            continue;
+        }
+        if (100*max_count_primary < min_percent[i]*sum_count_primary) {
+            continue;
+        }
+        drafted_token = max_token;
+    }
+
+    return drafted_token;
+}
+
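As a worked illustration of the acceptance rule in try_draft above, with made-up numbers: an n-gram needs at least min_sample_size observed continuations in total, and the most frequent continuation must account for at least min_percent of them, all in integer arithmetic:

    #include <cstdio>

    // Mirrors the integer threshold test in try_draft(); all numbers below are made up.
    static bool passes(int sum_count, int max_count, int min_sample_size, int min_percent) {
        return sum_count >= min_sample_size && 100*max_count >= min_percent*sum_count;
    }

    int main() {
        std::printf("%d\n", passes(4, 3, 2, 50)); // 1: 4 samples, top continuation holds 75% >= 50%
        std::printf("%d\n", passes(4, 1, 2, 50)); // 0: top continuation holds only 25%
    }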
+void llama_ngram_cache_draft(
+    std::vector<llama_token> & inp, std::vector<llama_token> & draft, int n_draft, int ngram_min, int ngram_max,
+    llama_ngram_cache & nc_context, llama_ngram_cache & nc_dynamic, llama_ngram_cache & nc_static
+) {
+    GGML_ASSERT(draft.size() == 1);
+    const int inp_size = inp.size();
+
+    if (inp_size < LLAMA_NGRAM_STATIC) {
+        return;
+    }
+
+    while ((int) draft.size()-1 < n_draft) {
+        llama_token drafted_token = -1;
+
+        const int ngram_start_static = inp_size-LLAMA_NGRAM_STATIC + draft.size()-1;
+        llama_ngram ngram_static;
+        for (int j = ngram_start_static; j < ngram_start_static + LLAMA_NGRAM_STATIC; ++j) {
+            ngram_static.tokens[j-ngram_start_static] = get_token(inp, draft, j);
+        }
+        llama_ngram_cache::iterator part_static_it = nc_static.find(ngram_static);
+        llama_ngram_cache_part part_static;
+        if (part_static_it != nc_static.end()) {
+            part_static = part_static_it->second;
+        }
+
+        // cd = context + dynamic
+        std::vector<llama_ngram> ngrams_cd;
+        for (int ngram_size_cd = ngram_min; ngram_size_cd <= ngram_max; ++ngram_size_cd) {
+            const int ngram_start_cd = inp_size-ngram_size_cd + draft.size()-1;
+            llama_ngram ngram_cd;
+            for (int j = ngram_start_cd; j < ngram_start_cd + ngram_size_cd; ++j) {
+                ngram_cd.tokens[j-ngram_start_cd] = get_token(inp, draft, j);
+            }
+            ngrams_cd.push_back(ngram_cd);
+        }
+        if (drafted_token == -1) {
+            drafted_token = try_draft(nc_context, ngrams_cd, part_static, draft_min_sample_size_lax, draft_min_percent_lax);
+        }
+        if (drafted_token == -1) {
+            drafted_token = try_draft(nc_dynamic, ngrams_cd, part_static, draft_min_sample_size_strict, draft_min_percent_strict);
+        }
+        if (drafted_token == -1) {
+            drafted_token = try_draft(nc_static, ngram_static);
+        }
+
+        if (drafted_token == -1) {
+            break;
+        }
+
+        LOG(" - draft candidate: token=%d\n", drafted_token);
+        draft.push_back(drafted_token);
+    }
+};
+
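A hedged sketch of a caller-side helper around this drafting routine: the draft vector must start with exactly the last accepted token, and drafted candidates come back appended after it. The n-gram range and helper name are illustrative; the real caller lives in examples/lookup, which is not shown in this excerpt:

    #include "ngram-cache.h"

    #include <vector>

    // Hypothetical caller-side helper: propose up to n_draft tokens following the tokens in inp.
    // The n-gram range 1..4 is an example value; the real wiring lives in examples/lookup.
    static std::vector<llama_token> propose_draft(
            std::vector<llama_token> & inp,      // tokens accepted so far (prompt + generated), non-empty
            llama_ngram_cache & nc_context,      // built incrementally from inp
            llama_ngram_cache & nc_dynamic,      // e.g. loaded via --lookup-cache-dynamic
            llama_ngram_cache & nc_static,       // e.g. loaded via --lookup-cache-static
            int n_draft) {
        std::vector<llama_token> draft = { inp.back() };   // must contain exactly the last token
        llama_ngram_cache_draft(inp, draft, n_draft, /*ngram_min=*/1, /*ngram_max=*/4,
                                nc_context, nc_dynamic, nc_static);
        return draft;   // draft[0] is the last accepted token, draft[1..] are candidates to verify
    }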
+void llama_ngram_cache_save(llama_ngram_cache & ngram_cache, std::string & filename) {
+    std::ofstream file_out(filename, std::ios::binary);
+    for (std::pair<llama_ngram, llama_ngram_cache_part> item : ngram_cache) {
+        const llama_ngram      ngram        = item.first;
+        llama_ngram_cache_part token_counts = item.second;
+        GGML_ASSERT(!token_counts.empty());
+        const int32_t ntokens = token_counts.size();
+
+        file_out.write(reinterpret_cast<const char *>(&ngram),   sizeof(llama_ngram));
+        file_out.write(reinterpret_cast<const char *>(&ntokens), sizeof(int32_t));
+        for (std::pair<llama_token, int32_t> item2 : token_counts) {
+            const llama_token token = item2.first;
+            const int32_t     count = item2.second;
+            file_out.write(reinterpret_cast<const char *>(&token), sizeof(llama_token));
+            file_out.write(reinterpret_cast<const char *>(&count), sizeof(int32_t));
+        }
+    }
+
+}
+
+llama_ngram_cache llama_ngram_cache_load(std::string & filename) {
+    std::ifstream hashmap_file(filename, std::ios::binary);
+    if (!hashmap_file) {
+        throw std::system_error();
+    }
+    llama_ngram_cache ngram_cache;
+
+    llama_ngram ngram;
+    int32_t     ntokens;
+    llama_token token;
+    int32_t     count;
+
+    char * ngramc   = reinterpret_cast<char*>(&ngram);
+    char * ntokensc = reinterpret_cast<char*>(&ntokens);
+    char * tokenc   = reinterpret_cast<char*>(&token);
+    char * countc   = reinterpret_cast<char*>(&count);
+    while(hashmap_file.read(ngramc, sizeof(llama_ngram))) {
+        GGML_ASSERT(hashmap_file.read(ntokensc, sizeof(int32_t)));
+        llama_ngram_cache_part token_counts;
+
+        for (int i = 0; i < ntokens; ++i) {
+            GGML_ASSERT(hashmap_file.read(tokenc, sizeof(llama_token)));
+            GGML_ASSERT(hashmap_file.read(countc, sizeof(int32_t)));
+            token_counts.emplace(token, count);
+        }
+
+        ngram_cache.emplace(ngram, token_counts);
+    }
+    GGML_ASSERT(hashmap_file.eof());
+
+    return ngram_cache;
+}
+
+void llama_ngram_cache_merge(llama_ngram_cache & ngram_cache_target, llama_ngram_cache & ngram_cache_add) {
+    for (std::pair<llama_ngram, llama_ngram_cache_part> ngram_part : ngram_cache_add) {
+        const llama_ngram      ngram = ngram_part.first;
+        llama_ngram_cache_part part  = ngram_part.second;
+
+        llama_ngram_cache::iterator part_merged_it = ngram_cache_target.find(ngram);
+        if (part_merged_it == ngram_cache_target.end()) {
+            ngram_cache_target.emplace(ngram, part);
+            continue;
+        }
+
+        for (std::pair<llama_token, int32_t> token_count : part) {
+            const llama_token token = token_count.first;
+            const int32_t     count = token_count.second;
+
+            llama_ngram_cache_part::iterator token_count_merged_it = part_merged_it->second.find(token);
+            if (token_count_merged_it == part_merged_it->second.end()) {
+                part_merged_it->second.emplace(token, count);
+                continue;
+            }
+
+            token_count_merged_it->second += count;
+        }
+    }
+}
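Together, save/load/merge are the building blocks for the offline tooling this commit introduces (lookup-create, lookup-merge, lookup-stats). A rough sketch of combining them; the file handling, helper name, and n-gram sizes are assumptions, since those example programs are not shown in this excerpt:

    #include "ngram-cache.h"

    #include <string>
    #include <system_error>
    #include <vector>

    // Hypothetical sketch: merge counts from newly generated tokens into an on-disk cache,
    // creating the file if it does not exist yet. Path handling and sizes are example values.
    static void update_cache_file(std::vector<llama_token> & new_tokens, std::string path) {
        llama_ngram_cache fresh;
        llama_ngram_cache_update(fresh, /*ngram_min=*/1, /*ngram_max=*/4,
                                 new_tokens, (int) new_tokens.size(), /*print_progress=*/false);

        llama_ngram_cache on_disk;
        try {
            on_disk = llama_ngram_cache_load(path);   // throws if the file cannot be opened
        } catch (const std::system_error &) {
            // no existing cache yet: start from the fresh counts only
        }

        llama_ngram_cache_merge(on_disk, fresh);      // add the new counts into on_disk
        llama_ngram_cache_save(on_disk, path);
    }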
