Skip to content

Commit b7efe2b

Browse files
ggerganovNeoZhangJianyu
authored andcommitted
examples : fix add_special conditions (ggml-org#11311)
1 parent 2a64a75 commit b7efe2b

File tree

2 files changed

+11
-7
lines changed

2 files changed

+11
-7
lines changed

examples/run/run.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -729,10 +729,12 @@ static int apply_chat_template(LlamaData & llama_data, const bool append) {
729729

730730
// Function to tokenize the prompt
731731
static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt,
732-
std::vector<llama_token> & prompt_tokens) {
733-
const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, true, true);
732+
std::vector<llama_token> & prompt_tokens, const LlamaData & llama_data) {
733+
const bool is_first = llama_get_kv_cache_used_cells(llama_data.context.get()) == 0;
734+
735+
const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true);
734736
prompt_tokens.resize(n_prompt_tokens);
735-
if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true,
737+
if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), is_first,
736738
true) < 0) {
737739
printe("failed to tokenize the prompt\n");
738740
return -1;
@@ -778,7 +780,7 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
778780
const llama_vocab * vocab = llama_model_get_vocab(llama_data.model.get());
779781

780782
std::vector<llama_token> tokens;
781-
if (tokenize_prompt(vocab, prompt, tokens) < 0) {
783+
if (tokenize_prompt(vocab, prompt, tokens, llama_data) < 0) {
782784
return 1;
783785
}
784786

examples/simple-chat/simple-chat.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,13 +95,15 @@ int main(int argc, char ** argv) {
9595
llama_sampler_chain_add(smpl, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));
9696

9797
// helper function to evaluate a prompt and generate a response
98-
auto generate = [&](const std::string & prompt, bool is_first) {
98+
auto generate = [&](const std::string & prompt) {
9999
std::string response;
100100

101+
const bool is_first = llama_get_kv_cache_used_cells(ctx) == 0;
102+
101103
// tokenize the prompt
102104
const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true);
103105
std::vector<llama_token> prompt_tokens(n_prompt_tokens);
104-
if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), llama_get_kv_cache_used_cells(ctx) == 0, true) < 0) {
106+
if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), is_first, true) < 0) {
105107
GGML_ABORT("failed to tokenize the prompt\n");
106108
}
107109

@@ -180,7 +182,7 @@ int main(int argc, char ** argv) {
180182

181183
// generate a response
182184
printf("\033[33m");
183-
std::string response = generate(prompt, prev_len == 0);
185+
std::string response = generate(prompt);
184186
printf("\n\033[0m");
185187

186188
// add the response to the messages

0 commit comments

Comments
 (0)