@@ -2269,6 +2269,12 @@ struct llama_context {
     // control vectors
     struct llama_control_vector cvec;
 
+    // caching token pieces & their decoded codepoints.
+    std::mutex token_cache_mutex;
+    std::vector<std::string> token_pieces;
+    std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>>
+        token_codepoints_without_partial_utf8_prefix;
+
 #ifdef GGML_USE_MPI
     ggml_mpi_context * ctx_mpi = NULL;
 #endif
@@ -13833,21 +13839,41 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
         }
     }
 
+    {
+        // cache tokens & their decoded codepoints (for common case where there's no partial utf8 prefix bytes) for grammar-constrained sampling.
+        std::unique_lock<std::mutex> lock(ctx->token_cache_mutex);
+        if (ctx->token_pieces.empty()) {
+            auto n_vocab = llama_n_vocab(llama_get_model(ctx));
+            ctx->token_codepoints_without_partial_utf8_prefix.resize(n_vocab);
+            ctx->token_pieces.resize(n_vocab);
+            for (llama_token id = 0; id < n_vocab; ++id) {
+                const std::string piece = llama_token_to_piece(ctx, id, false);
+                ctx->token_pieces[id] = piece;
+                ctx->token_codepoints_without_partial_utf8_prefix[id] = decode_utf8(piece, {0, 0});
+            }
+        }
+    }
+
+    // Store decoded codepoints when they are not cached (happens when there's a partial utf8 string prefix).
     std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
-    candidates_decoded.reserve(candidates->size);
+    if (grammar->partial_utf8.n_remain > 0) {
+        candidates_decoded.reserve(candidates->size);
+    }
     std::vector<llama_grammar_candidate> candidates_grammar;
     candidates_grammar.reserve(candidates->size);
 
     for (size_t i = 0; i < candidates->size; ++i) {
         const llama_token id = candidates->data[i].id;
-        const std::string piece = llama_token_to_piece(ctx, id, false);
-
+        const auto & piece = ctx->token_pieces[id];
         if (llama_token_is_eog(&ctx->model, id)) {
             if (!allow_eog) {
                 candidates->data[i].logit = -INFINITY;
             }
         } else if (piece.empty() || piece[0] == 0) {
             candidates->data[i].logit = -INFINITY;
+        } else if (grammar->partial_utf8.n_remain == 0) {
+            const auto & decoded = ctx->token_codepoints_without_partial_utf8_prefix.at(id);
+            candidates_grammar.push_back({ i, decoded.first.data(), decoded.second });
         } else {
             candidates_decoded.push_back(decode_utf8(piece, grammar->partial_utf8));
             candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second });
@@ -14040,10 +14066,12 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
         GGML_ASSERT(false);
     }
 
-    const std::string piece = llama_token_to_piece(ctx, token, false);
+    const auto & piece = ctx->token_pieces.at(token);
 
     // Note terminating 0 in decoded string
-    const auto decoded = decode_utf8(piece, grammar->partial_utf8);
+    const auto decoded = grammar->partial_utf8.n_remain == 0
+        ? ctx->token_codepoints_without_partial_utf8_prefix[token]
+        : decode_utf8(piece, grammar->partial_utf8);
     const auto & code_points = decoded.first;
     std::vector<std::vector<const llama_grammar_element *>> tmp_new_stacks;
     for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {