
Commit bc7a3c0

ngram: uint64_t -> struct with 4x int32_t

Parent: 1de9b5d

2 files changed: +46, −20 lines

common/common.cpp

7 additions, 16 deletions
@@ -1879,12 +1879,7 @@ void llama_ngram_cache_update(llama_ngram_cache & ngram_cache, int ngram_min, in
         const int i_start = std::max(inp_size - nnew, ngram_size);
         for (int i = i_start; i < inp_size; ++i) {
             const int ngram_start = i - ngram_size;
-            llama_ngram ngram = inp[ngram_start];
-            for (int j = ngram_start+1; j < ngram_start + ngram_size; ++j) { // FIXME
-                const llama_ngram ngram_part = inp[j];
-                ngram <<= 16;
-                ngram |= ngram_part;
-            }
+            llama_ngram ngram(&inp[ngram_start], ngram_size);
             const llama_token token = inp[i];
 
             llama_ngram_cache::iterator part_it = ngram_cache.find(ngram);
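The deleted loop is what the FIXME referred to: it packed each token into a 16-bit slice of a uint64_t key, but llama_token is a 32-bit id, so any id at or above 2^16 either spilled into the neighboring slice or collided with a smaller id. A minimal standalone sketch of the collision, not part of the commit, with the lossy truncation made explicit via a uint16_t cast (the original OR-ed the full 32-bit value instead):

```cpp
#include <cstdint>
#include <cstdio>

typedef int32_t llama_token; // token ids are 32 bit

int main() {
    // Two different 4-grams whose token ids agree modulo 2^16:
    const llama_token a_tok[4] = {70000, 3, 70001, 3};
    const llama_token b_tok[4] = { 4464, 3,  4465, 3}; // 70000 - 65536 == 4464

    uint64_t a = 0;
    uint64_t b = 0;
    for (int j = 0; j < 4; ++j) {
        a = (a << 16) | (uint16_t) a_tok[j]; // keep only the low 16 bits
        b = (b << 16) | (uint16_t) b_tok[j];
    }
    printf("%d\n", a == b); // prints 1: distinct n-grams collapse onto one key
}
```

The new llama_ngram constructor avoids this entirely by storing every token id in full.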
@@ -2019,11 +2014,9 @@ void llama_ngram_cache_draft(
         llama_token drafted_token = -1;
 
         const int ngram_start_static = inp_size-LLAMA_NGRAM_STATIC + draft.size()-1;
-        llama_ngram ngram_static = get_token(inp, draft, ngram_start_static);
-        for (int j = ngram_start_static+1; j < ngram_start_static + LLAMA_NGRAM_STATIC; ++j) {
-            const llama_ngram token = get_token(inp, draft, j);
-            ngram_static <<= 16;
-            ngram_static |= token;
+        llama_ngram ngram_static;
+        for (int j = ngram_start_static; j < ngram_start_static + LLAMA_NGRAM_STATIC; ++j) {
+            ngram_static.tokens[j-ngram_start_static] = get_token(inp, draft, j);
         }
         llama_ngram_cache::iterator part_static_it = nc_static.find(ngram_static);
         llama_ngram_cache_part part_static;
@@ -2035,11 +2028,9 @@
         std::vector<llama_ngram> ngrams_cd;
         for (int ngram_size_cd = ngram_min; ngram_size_cd <= ngram_max; ++ngram_size_cd) {
             const int ngram_start_cd = inp_size-ngram_size_cd + draft.size()-1;
-            llama_ngram ngram_cd = get_token(inp, draft, ngram_start_cd);
-            for (int j = ngram_start_cd+1; j < ngram_start_cd + ngram_size_cd; ++j) {
-                const llama_ngram token = get_token(inp, draft, j);
-                ngram_cd <<= 16;
-                ngram_cd |= token;
+            llama_ngram ngram_cd;
+            for (int j = ngram_start_cd; j < ngram_start_cd + ngram_size_cd; ++j) {
+                ngram_cd.tokens[j-ngram_start_cd] = get_token(inp, draft, j);
             }
             ngrams_cd.push_back(ngram_cd);
         }
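Both draft-side hunks follow the same pattern: the loop now starts at the first token of the n-gram and writes each id into its own tokens[] slot, and the default constructor's zero padding fills the unused slots so shorter n-grams still hash and compare consistently. A hedged sketch of that pattern follows; get_token is a reconstruction of its apparent contract at these call sites (index j addresses the concatenation of context and draft), and the concrete token values are illustrative:

```cpp
#include <cstdint>
#include <cstring>
#include <vector>

typedef int32_t llama_token;
#define LLAMA_NGRAM_MAX 4

struct llama_ngram {
    llama_token tokens[LLAMA_NGRAM_MAX];
    llama_ngram() { memset(tokens, 0, sizeof(tokens)); }
};

// Assumed behavior of the helper used by llama_ngram_cache_draft:
// read index i from the context if possible, otherwise from the draft.
static llama_token get_token(const std::vector<llama_token> & inp,
                             const std::vector<llama_token> & draft, int i) {
    return i < (int) inp.size() ? inp[i] : draft[1 + i - (int) inp.size()];
}

int main() {
    const std::vector<llama_token> inp   = {10, 11, 12, 13}; // context
    const std::vector<llama_token> draft = {13, 14};         // draft[0] mirrors the last context token

    const int ngram_size  = 2; // plays the role of LLAMA_NGRAM_STATIC or ngram_size_cd
    const int ngram_start = (int) inp.size() - ngram_size + (int) draft.size() - 1;

    llama_ngram ngram; // unused slots stay 0 thanks to the default constructor
    for (int j = ngram_start; j < ngram_start + ngram_size; ++j) {
        ngram.tokens[j - ngram_start] = get_token(inp, draft, j);
    }
    // ngram.tokens is now {13, 14, 0, 0}: the two most recent tokens, zero-padded
    return 0;
}
```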

common/common.h

39 additions, 4 deletions
@@ -268,11 +268,46 @@ void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size = 40
 #define LLAMA_NGRAM_STATIC 2
 
 // Data structures to map n-grams to empirical token probabilities:
-typedef uint64_t llama_ngram; // Each of the 4 16 bit sections represents a token id.
-typedef std::unordered_map<llama_token, int32_t> llama_ngram_cache_part; // token -> number of times token has been seen
-typedef std::unordered_map<llama_ngram, llama_ngram_cache_part> llama_ngram_cache; // n-gram -> empirical distribution of following tokens
 
-static_assert(LLAMA_NGRAM_MAX <= sizeof(llama_ngram)/2, "A 64 bit integer can only hold information for 4 16 bit tokens.");
+struct llama_ngram {
+    llama_token tokens[LLAMA_NGRAM_MAX];
+
+    llama_ngram() {
+        memset(tokens, 0, sizeof(tokens));
+    }
+
+    llama_ngram(const llama_token * input, const int ngram_size) {
+        for (int i = 0; i < LLAMA_NGRAM_MAX; ++i) {
+            tokens[i] = i < ngram_size ? input[i] : 0;
+        }
+    }
+
+    bool operator==(const llama_ngram & other) const {
+        for (int i = 0; i < LLAMA_NGRAM_MAX; ++i) {
+            if (tokens[i] != other.tokens[i]) {
+                return false;
+            }
+        }
+        return true;
+    }
+};
+
+struct llama_ngram_hash_function {
+    size_t operator()(const llama_ngram & ngram) const {
+        size_t hash = 0;
+        for (int i = 0; i < LLAMA_NGRAM_MAX; ++i) {
+            hash ^= std::hash<llama_token>{}(ngram.tokens[i]);
+        }
+        return hash;
+    }
+};
+
+// token -> number of times token has been seen
+typedef std::unordered_map<llama_token, int32_t> llama_ngram_cache_part;
+
+// n-gram -> empirical distribution of following tokens
+typedef std::unordered_map<llama_ngram, llama_ngram_cache_part, llama_ngram_hash_function> llama_ngram_cache;
+
 
 // Update an ngram cache with tokens.
 // ngram_cache: the cache to modify.
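With value equality on the struct and the companion hash functor, std::unordered_map can use llama_ngram as a key directly, and token ids above 65535 now round-trip intact. A small self-contained usage sketch, not from the commit (the token values and main() harness are illustrative):

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <functional>
#include <unordered_map>

typedef int32_t llama_token;
#define LLAMA_NGRAM_MAX 4

// Key type and hash functor as introduced by this commit:
struct llama_ngram {
    llama_token tokens[LLAMA_NGRAM_MAX];
    llama_ngram() { memset(tokens, 0, sizeof(tokens)); }
    llama_ngram(const llama_token * input, const int ngram_size) {
        for (int i = 0; i < LLAMA_NGRAM_MAX; ++i) {
            tokens[i] = i < ngram_size ? input[i] : 0;
        }
    }
    bool operator==(const llama_ngram & other) const {
        for (int i = 0; i < LLAMA_NGRAM_MAX; ++i) {
            if (tokens[i] != other.tokens[i]) return false;
        }
        return true;
    }
};

struct llama_ngram_hash_function {
    size_t operator()(const llama_ngram & ngram) const {
        size_t hash = 0;
        for (int i = 0; i < LLAMA_NGRAM_MAX; ++i) {
            hash ^= std::hash<llama_token>{}(ngram.tokens[i]);
        }
        return hash;
    }
};

typedef std::unordered_map<llama_token, int32_t> llama_ngram_cache_part;
typedef std::unordered_map<llama_ngram, llama_ngram_cache_part, llama_ngram_hash_function> llama_ngram_cache;

int main() {
    llama_ngram_cache cache;
    const llama_token ctx[2] = {70000, 3}; // an id above 65535 would have broken the old packing
    cache[llama_ngram(ctx, 2)][42]++;      // count token 42 as a continuation of this 2-gram
    printf("%d\n", cache[llama_ngram(ctx, 2)][42]); // prints 1
}
```

One property worth noting: the XOR combiner is order-insensitive, so permutations of the same tokens land in the same bucket; operator== still tells them apart, at the cost of extra collisions.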
