8
8
#include < fstream>
9
9
#include < thread>
10
10
11
- void llama_ngram_cache_update (llama_ngram_cache & ngram_cache, int ngram_min, int ngram_max,
11
+ void common_ngram_cache_update (common_ngram_cache & ngram_cache, int ngram_min, int ngram_max,
12
12
std::vector<llama_token> & inp, int nnew, bool print_progress) {
13
13
const int64_t t_start_ms = ggml_time_ms ();
14
14
const int64_t inp_size = inp.size ();
@@ -20,16 +20,16 @@ void llama_ngram_cache_update(llama_ngram_cache & ngram_cache, int ngram_min, in
20
20
const int64_t i_start = std::max (inp_size - nnew, ngram_size);
21
21
for (int64_t i = i_start; i < inp_size; ++i) {
22
22
const int64_t ngram_start = i - ngram_size;
23
- llama_ngram ngram (&inp[ngram_start], ngram_size);
23
+ common_ngram ngram (&inp[ngram_start], ngram_size);
24
24
const llama_token token = inp[i];
25
25
26
- llama_ngram_cache ::iterator part_it = ngram_cache.find (ngram);
26
+ common_ngram_cache ::iterator part_it = ngram_cache.find (ngram);
27
27
if (part_it == ngram_cache.end ()) {
28
- llama_ngram_cache_part part;
28
+ common_ngram_cache_part part;
29
29
part.emplace (token, 1 );
30
30
ngram_cache.emplace (ngram, part);
31
31
} else {
32
- llama_ngram_cache_part ::iterator token_count_it = part_it->second .find (token);
32
+ common_ngram_cache_part ::iterator token_count_it = part_it->second .find (token);
33
33
if (token_count_it == part_it->second .end ()) {
34
34
part_it->second .emplace (token, 1 );
35
35
} else {
@@ -62,12 +62,12 @@ constexpr int draft_min_sample_size_strict[LLAMA_NGRAM_MAX] = { 4, 3, 2, 2};
62
62
constexpr int draft_min_percent_strict[LLAMA_NGRAM_MAX] = {75 , 66 , 66 , 66 };
63
63
64
64
// Helper function that tries to draft a token from only the static ngram cache:
65
- static llama_token try_draft (llama_ngram_cache & nc_static, const llama_ngram ngram_static) {
66
- llama_ngram_cache ::iterator part_static_it = nc_static.find (ngram_static);
65
+ static llama_token try_draft (common_ngram_cache & nc_static, const common_ngram ngram_static) {
66
+ common_ngram_cache ::iterator part_static_it = nc_static.find (ngram_static);
67
67
if (part_static_it == nc_static.end ()) {
68
68
return -1 ;
69
69
}
70
- const llama_ngram_cache_part part_static = part_static_it->second ;
70
+ const common_ngram_cache_part part_static = part_static_it->second ;
71
71
72
72
int max_count_static = 0 ;
73
73
int sum_count_static = 0 ;
@@ -95,19 +95,19 @@ static llama_token try_draft(llama_ngram_cache & nc_static, const llama_ngram ng
95
95
96
96
// Try to draft a token from primary cache (context/dynamic), validate with static cache:
97
97
static llama_token try_draft (
98
- llama_ngram_cache & nc_primary, const std::vector<llama_ngram > & ngrams_primary, llama_ngram_cache_part & part_static,
98
+ common_ngram_cache & nc_primary, const std::vector<common_ngram > & ngrams_primary, common_ngram_cache_part & part_static,
99
99
const int * min_sample_size, const int * min_percent) {
100
100
101
101
llama_token drafted_token = -1 ;
102
102
103
103
for (int i = ngrams_primary.size ()-1 ; i >= 0 && drafted_token == -1 ; --i) {
104
- const llama_ngram ngram_primary = ngrams_primary[i];
104
+ const common_ngram ngram_primary = ngrams_primary[i];
105
105
106
- llama_ngram_cache ::iterator part_primary_it = nc_primary.find (ngram_primary);
106
+ common_ngram_cache ::iterator part_primary_it = nc_primary.find (ngram_primary);
107
107
if (part_primary_it == nc_primary.end ()) {
108
108
continue ;
109
109
}
110
- const llama_ngram_cache_part part_primary = part_primary_it->second ;
110
+ const common_ngram_cache_part part_primary = part_primary_it->second ;
111
111
112
112
int max_count_primary = 0 ;
113
113
int max_count_static = 0 ;
@@ -117,7 +117,7 @@ static llama_token try_draft(
117
117
for (std::pair<llama_token, int > token_count_primary : part_primary) {
118
118
const llama_token token = token_count_primary.first ;
119
119
120
- llama_ngram_cache_part ::iterator token_count_static_it = part_static.find (token);
120
+ common_ngram_cache_part ::iterator token_count_static_it = part_static.find (token);
121
121
122
122
const int32_t count_primary = token_count_primary.second ;
123
123
const int32_t count_static = token_count_static_it != part_static.end () ? 100 *token_count_static_it->second : 1 ;
@@ -142,9 +142,9 @@ static llama_token try_draft(
142
142
return drafted_token;
143
143
}
144
144
145
- void llama_ngram_cache_draft (
145
+ void common_ngram_cache_draft (
146
146
std::vector<llama_token> & inp, std::vector<llama_token> & draft, int n_draft, int ngram_min, int ngram_max,
147
- llama_ngram_cache & nc_context, llama_ngram_cache & nc_dynamic, llama_ngram_cache & nc_static
147
+ common_ngram_cache & nc_context, common_ngram_cache & nc_dynamic, common_ngram_cache & nc_static
148
148
) {
149
149
GGML_ASSERT (draft.size () == 1 );
150
150
const int inp_size = inp.size ();
@@ -157,21 +157,21 @@ void llama_ngram_cache_draft(
157
157
llama_token drafted_token = -1 ;
158
158
159
159
const int ngram_start_static = inp_size-LLAMA_NGRAM_STATIC + draft.size ()-1 ;
160
- llama_ngram ngram_static;
160
+ common_ngram ngram_static;
161
161
for (int j = ngram_start_static; j < ngram_start_static + LLAMA_NGRAM_STATIC; ++j) {
162
162
ngram_static.tokens [j-ngram_start_static] = get_token (inp, draft, j);
163
163
}
164
- llama_ngram_cache ::iterator part_static_it = nc_static.find (ngram_static);
165
- llama_ngram_cache_part part_static;
164
+ common_ngram_cache ::iterator part_static_it = nc_static.find (ngram_static);
165
+ common_ngram_cache_part part_static;
166
166
if (part_static_it != nc_static.end ()) {
167
167
part_static = part_static_it->second ;
168
168
}
169
169
170
170
// cd = context + dynamic
171
- std::vector<llama_ngram > ngrams_cd;
171
+ std::vector<common_ngram > ngrams_cd;
172
172
for (int ngram_size_cd = ngram_min; ngram_size_cd <= ngram_max; ++ngram_size_cd) {
173
173
const int ngram_start_cd = inp_size-ngram_size_cd + draft.size ()-1 ;
174
- llama_ngram ngram_cd;
174
+ common_ngram ngram_cd;
175
175
for (int j = ngram_start_cd; j < ngram_start_cd + ngram_size_cd; ++j) {
176
176
ngram_cd.tokens [j-ngram_start_cd] = get_token (inp, draft, j);
177
177
}
@@ -196,16 +196,16 @@ void llama_ngram_cache_draft(
196
196
}
197
197
}
198
198
199
- void llama_ngram_cache_save (llama_ngram_cache & ngram_cache, std::string & filename) {
199
+ void common_ngram_cache_save (common_ngram_cache & ngram_cache, std::string & filename) {
200
200
std::ofstream file_out (filename, std::ios::binary);
201
- for (std::pair<llama_ngram, llama_ngram_cache_part > item : ngram_cache) {
202
- const llama_ngram ngram = item.first ;
203
- llama_ngram_cache_part token_counts = item.second ;
201
+ for (std::pair<common_ngram, common_ngram_cache_part > item : ngram_cache) {
202
+ const common_ngram ngram = item.first ;
203
+ common_ngram_cache_part token_counts = item.second ;
204
204
GGML_ASSERT (!token_counts.empty ());
205
205
const int32_t ntokens = token_counts.size ();
206
206
GGML_ASSERT (ntokens > 0 );
207
207
208
- file_out.write (reinterpret_cast <const char *>(&ngram), sizeof (llama_ngram ));
208
+ file_out.write (reinterpret_cast <const char *>(&ngram), sizeof (common_ngram ));
209
209
file_out.write (reinterpret_cast <const char *>(&ntokens), sizeof (int32_t ));
210
210
for (std::pair<llama_token, int32_t > item2 : token_counts) {
211
211
const llama_token token = item2.first ;
@@ -219,14 +219,14 @@ void llama_ngram_cache_save(llama_ngram_cache & ngram_cache, std::string & filen
219
219
220
220
}
221
221
222
- llama_ngram_cache llama_ngram_cache_load (std::string & filename) {
222
+ common_ngram_cache common_ngram_cache_load (std::string & filename) {
223
223
std::ifstream hashmap_file (filename, std::ios::binary);
224
224
if (!hashmap_file) {
225
225
throw std::ifstream::failure (" Unable to open file " + filename);
226
226
}
227
- llama_ngram_cache ngram_cache;
227
+ common_ngram_cache ngram_cache;
228
228
229
- llama_ngram ngram;
229
+ common_ngram ngram;
230
230
int32_t ntokens;
231
231
llama_token token;
232
232
int32_t count;
@@ -235,11 +235,11 @@ llama_ngram_cache llama_ngram_cache_load(std::string & filename) {
235
235
char * ntokensc = reinterpret_cast <char *>(&ntokens);
236
236
char * tokenc = reinterpret_cast <char *>(&token);
237
237
char * countc = reinterpret_cast <char *>(&count);
238
- while (hashmap_file.read (ngramc, sizeof (llama_ngram ))) {
238
+ while (hashmap_file.read (ngramc, sizeof (common_ngram ))) {
239
239
GGML_ASSERT (!hashmap_file.eof ());
240
240
GGML_ASSERT (hashmap_file.read (ntokensc, sizeof (int32_t )));
241
241
GGML_ASSERT (ntokens > 0 );
242
- llama_ngram_cache_part token_counts;
242
+ common_ngram_cache_part token_counts;
243
243
244
244
for (int i = 0 ; i < ntokens; ++i) {
245
245
GGML_ASSERT (!hashmap_file.eof ());
@@ -257,12 +257,12 @@ llama_ngram_cache llama_ngram_cache_load(std::string & filename) {
257
257
return ngram_cache;
258
258
}
259
259
260
- void llama_ngram_cache_merge (llama_ngram_cache & ngram_cache_target, llama_ngram_cache & ngram_cache_add) {
261
- for (std::pair<llama_ngram, llama_ngram_cache_part > ngram_part : ngram_cache_add) {
262
- const llama_ngram ngram = ngram_part.first ;
263
- llama_ngram_cache_part part = ngram_part.second ;
260
+ void common_ngram_cache_merge (common_ngram_cache & ngram_cache_target, common_ngram_cache & ngram_cache_add) {
261
+ for (std::pair<common_ngram, common_ngram_cache_part > ngram_part : ngram_cache_add) {
262
+ const common_ngram ngram = ngram_part.first ;
263
+ common_ngram_cache_part part = ngram_part.second ;
264
264
265
- llama_ngram_cache ::iterator part_merged_it = ngram_cache_target.find (ngram);
265
+ common_ngram_cache ::iterator part_merged_it = ngram_cache_target.find (ngram);
266
266
if (part_merged_it == ngram_cache_target.end ()) {
267
267
ngram_cache_target.emplace (ngram, part);
268
268
continue ;
@@ -273,7 +273,7 @@ void llama_ngram_cache_merge(llama_ngram_cache & ngram_cache_target, llama_ngram
273
273
const int32_t count = token_count.second ;
274
274
GGML_ASSERT (count > 0 );
275
275
276
- llama_ngram_cache_part ::iterator token_count_merged_it = part_merged_it->second .find (token);
276
+ common_ngram_cache_part ::iterator token_count_merged_it = part_merged_it->second .find (token);
277
277
if (token_count_merged_it == part_merged_it->second .end ()) {
278
278
part_merged_it->second .emplace (token, count);
279
279
continue ;
0 commit comments