
Commit c8287f6

lookup: use static data from text corpus
1 parent e3d8c5e commit c8287f6

4 files changed: +220 -6 lines changed

Makefile

Lines changed: 2 additions & 0 deletions

@@ -744,6 +744,8 @@ lookahead: examples/lookahead/lookahead.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
 lookup: examples/lookup/lookup.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
 	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
+	$(CXX) $(CXXFLAGS) -c examples/lookup/lookup-create.cpp -o $(call GET_OBJ_FILE, examples/lookup/lookup-create.cpp)
+	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, examples/lookup/lookup-create.cpp) -o lookup-create $(LDFLAGS)
 
 passkey: examples/passkey/passkey.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
 	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)

examples/lookup/lookup-create.cpp

Lines changed: 119 additions & 0 deletions

@@ -0,0 +1,119 @@
+#include "common.h"
+#include "ggml.h"
+#include "llama.h"
+
+#include <cstdint>
+#include <fstream>
+#include <iostream>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+typedef std::unordered_map<llama_token, int32_t> token_hashmap; // token -> number of times token has been seen
+typedef std::unordered_map<uint64_t, token_hashmap> all_token_hashmap; // n-gram -> empirical distribution of following tokens
+constexpr int ngram_size = 2;
+
+int main(int argc, char ** argv){
+    gpt_params params;
+
+    if (!gpt_params_parse(argc, argv, params)) {
+        return 1;
+    }
+    // init llama.cpp
+    llama_backend_init(params.numa);
+
+    llama_model * model = NULL;
+    llama_context * ctx = NULL;
+
+    // load the model
+    std::tie(model, ctx) = llama_init_from_gpt_params(params);
+    GGML_ASSERT(model != nullptr);
+
+    // tokenize the prompt
+    const bool add_bos = llama_should_add_bos_token(model);
+
+    const char * static_input_file = "./wikitext-2-raw/wiki.train.raw";
+    std::ifstream file(static_input_file);
+    if (!file) {
+        fprintf(stderr, "error: failed to open file '%s'\n", static_input_file);
+        exit(1);
+    }
+    std::string static_input;
+    std::copy(std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>(), back_inserter(static_input));
+    if (!static_input.empty() && static_input.back() == '\n') {
+        static_input.pop_back();
+    }
+    std::vector<llama_token> inp_static;
+    inp_static = ::llama_tokenize(ctx, static_input, add_bos, true);
+    fprintf(stderr, "lookup-create: tokenization done\n");
+
+    auto update_hashmaps = [](all_token_hashmap * atc, const llama_token * inp_data, const int inp_size, const int nnew) -> void {
+        // atcs = all_token_counts: the hashmaps to modify.
+        // inp_data: the token sequence on which the hashmaps are based.
+        // inp_size: the current size of inp_data.
+        // nnew: how many new tokens have been appended to inp_data since the last call to this function.
+        //
+        // In order to get correct results inp_data can ONLY BE APPENDED TO.
+        // Changes in the middle need a complete rebuild.
+
+        const int i_start = std::max(inp_size - nnew, ngram_size);
+        const int64_t t_start_ms = ggml_time_ms();
+        int percentage_done = 0;
+        for (int i = i_start; i < inp_size; ++i) {
+            const int ngram_start = i - ngram_size;
+            uint64_t ngram = inp_data[ngram_start];
+            for (int j = ngram_start; j < ngram_start + ngram_size; ++j) {
+                const uint64_t ngram_part = inp_data[j];
+                ngram <<= 16;
+                ngram |= ngram_part;
+            }
+            const llama_token token = inp_data[i];
+
+            all_token_hashmap::iterator token_counts_it = atc->find(ngram);
+            if (token_counts_it == atc->end()) {
+                token_hashmap token_counts;
+                token_counts.emplace(token, 1);
+                atc->emplace(ngram, token_counts);
+            } else {
+                token_hashmap::iterator tc_it = token_counts_it->second.find(token);
+                if (tc_it == token_counts_it->second.end()) {
+                    token_counts_it->second.emplace(token, 1);
+                } else {
+                    tc_it->second++;
+                }
+            }
+
+            if (i >= inp_size*(percentage_done + 1)/100) {
+                ++percentage_done;
+
+                const int64_t t_now_ms = ggml_time_ms();
+                const int64_t eta_ms = (100 - percentage_done) * (t_now_ms - t_start_ms) / percentage_done;
+                const int64_t eta_min = eta_ms / (60*1000);
+                const int64_t eta_s = (eta_ms - eta_min) / 1000;
+
+                fprintf(stderr, "lookup-create: %02d%% done, ETA: %02ld:%02ld\n", percentage_done, eta_min, eta_s);
+            }
+        }
+    };
+
+    all_token_hashmap atc;
+    update_hashmaps(&atc, inp_static.data(), inp_static.size(), inp_static.size());
+
+    std::ofstream file_out("lookup.bin", std::ios::binary);
+    for (std::pair<uint64_t, token_hashmap> item : atc) {
+        const uint64_t ngram = item.first;
+        token_hashmap token_counts = item.second;
+        GGML_ASSERT(!token_counts.empty());
+        const int32_t ntokens = token_counts.size();
+
+
+        file_out.write(reinterpret_cast<const char *>(&ngram), sizeof(uint64_t));
+        file_out.write(reinterpret_cast<const char *>(&ntokens), sizeof(int32_t));
+        for (std::pair<llama_token, int32_t> item2 : token_counts) {
+            const llama_token token = item2.first;
+            const int32_t count = item2.second;
+            file_out.write(reinterpret_cast<const char *>(&token), sizeof(llama_token));
+            file_out.write(reinterpret_cast<const char *>(&count), sizeof(int32_t));
+        }
+    }
+}
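The packing scheme used by update_hashmaps above (and documented in lookup.cpp below) relies on every token id fitting into 16 bits, so an n-gram of up to 4 tokens folds into a single uint64_t that std::unordered_map can hash without a custom hash function. A minimal, self-contained sketch of that idea; pack_ngram and the example token ids are illustrative only, and unlike the commit's loop the key here is seeded with zero rather than with the leading token:

#include <cstdint>
#include <cstdio>
#include <vector>

// Illustrative: fold up to 4 token ids (each assumed to fit in 16 bits)
// into one 64-bit key, one id per 16-bit section.
static uint64_t pack_ngram(const std::vector<int32_t> & tokens) {
    uint64_t ngram = 0;
    for (const int32_t t : tokens) {
        ngram <<= 16;
        ngram |= static_cast<uint16_t>(t);
    }
    return ngram;
}

int main() {
    // Hypothetical token ids; real ones would come from llama_tokenize().
    const uint64_t key = pack_ngram({1234, 5678});
    printf("packed 2-gram key: 0x%016llx\n", (unsigned long long) key);
    return 0;
}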

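The writer loop above also fixes the on-disk layout of lookup.bin: a headerless stream of records, each holding a uint64_t n-gram key, an int32_t count of continuations, and that many (token, count) pairs. A sketch of a standalone dumper for that layout, assuming llama_token is a 32-bit integer (consistent with the sizeof writes above); this helper is hypothetical and not part of the commit:

#include <cstdint>
#include <cstdio>
#include <fstream>

// Hypothetical debugging helper: walk lookup.bin and print each record.
// Record layout (matching the writes in lookup-create.cpp):
//   uint64_t ngram; int32_t ntokens; ntokens x { int32_t token; int32_t count; }
int main() {
    std::ifstream in("lookup.bin", std::ios::binary);
    if (!in) {
        fprintf(stderr, "error: failed to open lookup.bin\n");
        return 1;
    }
    uint64_t ngram;
    while (in.read(reinterpret_cast<char *>(&ngram), sizeof(ngram))) {
        int32_t ntokens;
        in.read(reinterpret_cast<char *>(&ntokens), sizeof(ntokens));
        printf("ngram 0x%016llx: %d continuation(s)\n", (unsigned long long) ngram, ntokens);
        for (int32_t i = 0; i < ntokens; ++i) {
            int32_t token;
            int32_t count;
            in.read(reinterpret_cast<char *>(&token), sizeof(token));
            in.read(reinterpret_cast<char *>(&count), sizeof(count));
            printf("    token %6d seen %d time(s)\n", token, count);
        }
    }
    return 0;
}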
examples/lookup/lookup.cpp

Lines changed: 89 additions & 6 deletions

@@ -5,12 +5,13 @@
 #include <cmath>
 #include <cstdint>
 #include <cstdio>
+#include <fstream>
 #include <string>
 #include <vector>
 #include <unordered_map>
 
 // Data structures to map n-grams to empirical token probabilities:
-typedef std::unordered_map<llama_token, int> token_hashmap; // token -> number of times token has been seen
+typedef std::unordered_map<llama_token, int32_t> token_hashmap; // token -> number of times token has been seen
 typedef std::unordered_map<uint64_t, token_hashmap> all_token_hashmap; // n-gram -> empirical distribution of following tokens
 // n-grams are encoded as 64 bit integers with each of the 4 16 bit sections representing a token id.
 // This way no custom hashing function for the n-grams is needed.
@@ -22,7 +23,7 @@ static_assert(ngram_max <= sizeof(uint64_t)/2, "A 64 bit integer can only hold i
 
 // If sample size or percentage in context are below these thresholds the draft is aborted early:
 constexpr float draft_min_sample_size[ngram_max] = { 2, 2, 1, 1};
-constexpr float draft_min_percent[ngram_max] = {66, 50, 50, 50};
+constexpr float draft_min_percent[ngram_max] = {50, 50, 50, 50};
 
 int main(int argc, char ** argv){
     gpt_params params;
@@ -100,12 +101,43 @@ int main(int argc, char ** argv){
     };
 
     all_token_hashmap all_token_counts[ngram_max-ngram_min+1];
+    all_token_hashmap static_all_token_counts;
     int64_t t_draft_us = 0;
 
     {
         // Fill up hashmaps with tokens from user input:
         const int64_t t_start_draft_us = ggml_time_us();
         update_hashmaps(all_token_counts, inp.data(), inp.size(), inp.size());
+
+        const char * hashmap_file_name = "lookup.bin";
+        std::ifstream hashmap_file(hashmap_file_name, std::ios::binary);
+        if (!hashmap_file) {
+            fprintf(stderr, "error: failed to open file '%s'\n", hashmap_file_name);
+            exit(1);
+        }
+        uint64_t ngram;
+        int32_t ntokens;
+        llama_token token;
+        int32_t count;
+
+        char * ngramc = reinterpret_cast<char*>(&ngram);
+        char * ntokensc = reinterpret_cast<char*>(&ntokens);
+        char * tokenc = reinterpret_cast<char*>(&token);
+        char * countc = reinterpret_cast<char*>(&count);
+        while(hashmap_file.read(ngramc, sizeof(uint64_t))) {
+            GGML_ASSERT(hashmap_file.read(ntokensc, sizeof(int32_t)));
+            token_hashmap token_counts;
+
+            for (int i = 0; i < ntokens; ++i) {
+                GGML_ASSERT(hashmap_file.read(tokenc, sizeof(llama_token)));
+                GGML_ASSERT(hashmap_file.read(countc, sizeof(int32_t)));
+                token_counts.emplace(token, count);
+            }
+
+            static_all_token_counts.emplace(ngram, token_counts);
+        }
+        GGML_ASSERT(hashmap_file.eof());
+
         t_draft_us += ggml_time_us() - t_start_draft_us;
     }
 
@@ -248,6 +280,20 @@ int main(int argc, char ** argv){
 
         while ((int) draft.size()-1 < n_draft) {
             bool draft_success = false;
+
+            const int static_ngram_start = inp_size-2 + draft.size()-1;
+            uint64_t static_ngram = get_token(inp, draft, static_ngram_start);
+            for (int j = static_ngram_start; j < static_ngram_start + 2; ++j) {
+                const uint64_t ngram_part = get_token(inp, draft, j);
+                static_ngram <<= 16;
+                static_ngram |= ngram_part;
+            }
+            all_token_hashmap::iterator static_token_counts_it = static_all_token_counts.find(static_ngram);
+            token_hashmap static_token_counts;
+            if (static_token_counts_it != static_all_token_counts.end()) {
+                static_token_counts = static_token_counts_it->second;
+            }
+
             for (int ngram_size = ngram_max; ngram_size >= ngram_min; --ngram_size) {
                 if (ngram_size > inp_size) {
                     continue;
@@ -270,16 +316,21 @@ int main(int argc, char ** argv){
                 const token_hashmap token_counts = token_counts_it->second;
 
                 int max_count = 0;
+                int max_count_static = 0;
                 int sum_count = 0;
                 llama_token max_token = -1;
 
                 for (std::pair<llama_token, int> tc : token_counts) {
                     const llama_token token = tc.first;
-                    const llama_token count = tc.second;
 
-                    if (count > max_count) {
-                        max_token = token;
-                        max_count = count;
+                    token_hashmap::iterator stc_it = static_token_counts.find(token);
+                    const int32_t count = tc.second;
+                    const int32_t count_static = stc_it != static_token_counts.end() ? 100*stc_it->second : 1;
+
+                    if (count*count_static > max_count*max_count_static) {
+                        max_token = token;
+                        max_count = count;
+                        max_count_static = count_static;
                     }
                     sum_count += count;
                 }
@@ -299,6 +350,38 @@ int main(int argc, char ** argv){
                 break;
             }
 
+            if (!draft_success) {
+                int max_count = 0;
+                int sum_count = 0;
+                llama_token max_token = -1;
+
+                for (std::pair<llama_token, int> tc : static_token_counts) {
+                    const llama_token token = tc.first;
+                    const int32_t count = tc.second;
+
+                    if (count > max_count) {
+                        max_token = token;
+                        max_count = count;
+                    }
+                    sum_count += count;
+                }
+
+                // Skip this candidate if the sample size is too low:
+                if (sum_count < draft_min_sample_size[2-1]) {
+                    break;
+                }
+                // skip this candidate if the empirically most likely token following this token is not likely enough:
+                if (100*max_count < draft_min_percent[2-1]*sum_count) {
+                    break;
+                }
+
+                LOG(" - draft candidate: token=%d count=%d\n", max_token, max_count);
+                llama_batch_add(batch_tgt, max_token, n_past + draft.size(), { 0 }, true);
+                draft.push_back(max_token);
+                draft_success = true;
+                break;
+            }
+
             if (!draft_success) {
                 break;
             }
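Taken together, these hunks rank each draft candidate by the product of its context count and 100 times its static-corpus count, falling back to a factor of 1 when the corpus never saw the token after this n-gram, so corpus evidence dominates the ranking while only tokens already seen in the context can win this branch; a separate fallback (the @@ -299,6 +350,38 @@ hunk) drafts from the static distribution alone using the 2-gram thresholds. A condensed sketch of the combined scoring rule, with an illustrative function name and 64-bit products to sidestep overflow (the commit inlines this logic with int arithmetic):

#include <cstdint>
#include <unordered_map>

typedef int32_t llama_token_t; // stand-in for llama_token (a 32-bit id)
typedef std::unordered_map<llama_token_t, int32_t> counts_t;

// Illustrative version of the combined static/dynamic ranking used above:
// score = context_count * (100 * corpus_count, or 1 if unseen in the corpus).
static llama_token_t pick_draft_token(const counts_t & context_counts, const counts_t & corpus_counts) {
    int64_t best_score = 0;
    llama_token_t best_token = -1;
    for (const auto & tc : context_counts) {
        const counts_t::const_iterator it = corpus_counts.find(tc.first);
        const int64_t count_static = it != corpus_counts.end() ? 100 * it->second : 1;
        const int64_t score = (int64_t) tc.second * count_static;
        if (score > best_score) {
            best_score = score;
            best_token = tc.first;
        }
    }
    return best_token; // -1 if context_counts is empty
}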

scripts/get-wikitext-103.sh

Lines changed: 10 additions & 0 deletions

@@ -0,0 +1,10 @@
+#!/bin/bash
+
+wget https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-raw-v1.zip
+
+echo "Usage:"
+echo ""
+echo " ./perplexity -m model.gguf -f wiki.test.raw [other params]"
+echo ""
+
+exit 0
