Commit 1fa5a4b

Author: ochafik (committed)
grammars: early exit when no next_candidates to reject
grammars: cache decoded tokens
grammars: faster llama_grammar_copy
grammars: fix bad merge
grammars: keep llama_grammar_copy non-quadratic optim for later
grammars: move token caches to llama_context
grammars: cache codepoints in llama_new_context_with_model
grammar: nit (layout)
grammars: nits (revert const grammar sig, fix comment)
Update llama.cpp
Co-authored-by: Clint Herron <[email protected]>
grammars: mutex-guarded lazy caching of token pieces in llama_sample_grammar
grammars: remove early exit
--> ggml-org#7370
1 parent 059031b
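Editor's note: the key item in the message above is the mutex-guarded lazy caching of token pieces in llama_sample_grammar. The first grammar-constrained sampling call detokenizes the whole vocabulary once, under a lock, into vectors stored on llama_context; every later call only indexes into them. Below is a minimal standalone sketch of that pattern; token_cache and detokenize are hypothetical stand-ins, not llama.cpp APIs.

#include <mutex>
#include <string>
#include <vector>

// Hypothetical holder mirroring the new llama_context members.
struct token_cache {
    std::mutex               mutex;
    std::vector<std::string> pieces; // filled once, read-only afterwards

    // Lazily build the cache on first use; later calls only take the lock briefly.
    const std::string & piece(int id, int n_vocab) {
        {
            std::unique_lock<std::mutex> lock(mutex);
            if (pieces.empty()) {
                pieces.resize(n_vocab);
                for (int i = 0; i < n_vocab; ++i) {
                    pieces[i] = detokenize(i); // stand-in for llama_token_to_piece
                }
            }
        }
        // Safe to read without the lock: the vector is never modified after initialization.
        return pieces[id];
    }

    static std::string detokenize(int id) { // hypothetical detokenizer
        return "token_" + std::to_string(id);
    }
};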

1 file changed (+33, -5 lines)

llama.cpp

Lines changed: 33 additions & 5 deletions
@@ -2269,6 +2269,12 @@ struct llama_context {
     // control vectors
     struct llama_control_vector cvec;
 
+    // caching token pieces & their decoded codepoints.
+    std::mutex token_cache_mutex;
+    std::vector<std::string> token_pieces;
+    std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>>
+        token_codepoints_without_partial_utf8_prefix;
+
 #ifdef GGML_USE_MPI
     ggml_mpi_context * ctx_mpi = NULL;
 #endif
@@ -13833,21 +13839,41 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
         }
     }
 
+    {
+        // cache tokens & their decoded codepoints (for common case where there's no partial utf8 prefix bytes) for grammar-constrained sampling.
+        std::unique_lock<std::mutex> lock(ctx->token_cache_mutex);
+        if (ctx->token_pieces.empty()) {
+            auto n_vocab = llama_n_vocab(llama_get_model(ctx));
+            ctx->token_codepoints_without_partial_utf8_prefix.resize(n_vocab);
+            ctx->token_pieces.resize(n_vocab);
+            for (llama_token id = 0; id < n_vocab; ++id) {
+                const std::string piece = llama_token_to_piece(ctx, id, false);
+                ctx->token_pieces[id] = piece;
+                ctx->token_codepoints_without_partial_utf8_prefix[id] = decode_utf8(piece, {0, 0});
+            }
+        }
+    }
+
+    // Store decoded codepoints when they are not cached (happens when there's a partial utf8 string prefix).
     std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
-    candidates_decoded.reserve(candidates->size);
+    if (grammar->partial_utf8.n_remain > 0) {
+        candidates_decoded.reserve(candidates->size);
+    }
     std::vector<llama_grammar_candidate> candidates_grammar;
     candidates_grammar.reserve(candidates->size);
 
     for (size_t i = 0; i < candidates->size; ++i) {
         const llama_token id = candidates->data[i].id;
-        const std::string piece = llama_token_to_piece(ctx, id, false);
-
+        const auto & piece = ctx->token_pieces[id];
         if (llama_token_is_eog(&ctx->model, id)) {
             if (!allow_eog) {
                 candidates->data[i].logit = -INFINITY;
             }
         } else if (piece.empty() || piece[0] == 0) {
             candidates->data[i].logit = -INFINITY;
+        } else if (grammar->partial_utf8.n_remain == 0){
+            const auto & decoded = ctx->token_codepoints_without_partial_utf8_prefix.at(id);
+            candidates_grammar.push_back({ i, decoded.first.data(), decoded.second });
         } else {
             candidates_decoded.push_back(decode_utf8(piece, grammar->partial_utf8));
             candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second });
@@ -14040,10 +14066,12 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
         GGML_ASSERT(false);
     }
 
-    const std::string piece = llama_token_to_piece(ctx, token, false);
+    const auto & piece = ctx->token_pieces.at(token);
 
     // Note terminating 0 in decoded string
-    const auto decoded = decode_utf8(piece, grammar->partial_utf8);
+    const auto decoded = grammar->partial_utf8.n_remain == 0
+        ? ctx->token_codepoints_without_partial_utf8_prefix[token]
+        : decode_utf8(piece, grammar->partial_utf8);
     const auto & code_points = decoded.first;
     std::vector<std::vector<const llama_grammar_element *>> tmp_new_stacks;
     for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
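Editor's note on the last two hunks: the cached codepoints are produced with an empty partial-UTF-8 state (decode_utf8(piece, {0, 0})), so they can only be reused when grammar->partial_utf8.n_remain == 0; leftover prefix bytes from the previous token still force a fresh decode_utf8 call. A minimal sketch of that dispatch, with partial_utf8, decode_bytes and pick_decoded as hypothetical stand-ins for the llama.cpp types and functions:

#include <cstdint>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-ins for llama_partial_utf8 and decode_utf8's return type.
struct partial_utf8 { uint32_t value = 0; int n_remain = 0; };
using decoded_t = std::pair<std::vector<uint32_t>, partial_utf8>;

// Hypothetical decoder: the real decode_utf8 resumes from the pending partial state.
static decoded_t decode_bytes(const std::string & piece, const partial_utf8 & state) {
    decoded_t out;
    for (unsigned char c : piece) {
        out.first.push_back(c);     // placeholder: byte-per-codepoint, not real UTF-8
    }
    out.first.push_back(0);         // terminating 0, as noted in the diff
    out.second = partial_utf8{};
    (void) state;
    return out;
}

// Reuse the prebuilt cache only when no partial UTF-8 prefix is pending.
static decoded_t pick_decoded(const std::vector<decoded_t> & cache, int token,
                              const std::string & piece, const partial_utf8 & state) {
    return state.n_remain == 0
        ? cache.at(token)               // fast path: cached decode for this token
        : decode_bytes(piece, state);   // slow path: resume from the partial prefix
}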
