
Commit 48e6b92

ngxson and ggerganov authored
Add chat template support for llama-cli (#8068)
* add chat template support for llama-cli
* add help message
* server: simplify format_chat
* more consistent naming
* improve
* add llama_chat_format_example
* fix server
* code style
* code style
* Update examples/main/main.cpp

Co-authored-by: Georgi Gerganov <[email protected]>

---------

Co-authored-by: Georgi Gerganov <[email protected]>
1 parent 3791ad2 commit 48e6b92

File tree: 7 files changed, +154 -49 lines changed


common/common.cpp

Lines changed: 59 additions & 1 deletion
@@ -1444,7 +1444,10 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
     options.push_back({ "main",  "       --cfg-negative-prompt-file FNAME",
                         "negative prompt file to use for guidance" });
     options.push_back({ "main",  "       --cfg-scale N",         "strength of guidance (default: %.1f, 1.0 = disable)", (double)sparams.cfg_scale });
-
+    options.push_back({ "main",  "       --chat-template JINJA_TEMPLATE",
+                        "set custom jinja chat template (default: template taken from model's metadata)\n"
+                        "only commonly used templates are accepted:\n"
+                        "https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template" });
     options.push_back({ "grammar" });
     options.push_back({ "*",     "       --grammar GRAMMAR",     "BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '%s')", sparams.grammar.c_str() });
     options.push_back({ "*",     "       --grammar-file FNAME",  "file to read grammar from" });
@@ -2604,12 +2607,67 @@ bool llama_should_add_bos_token(const llama_model * model) {
     return add_bos != -1 ? bool(add_bos) : (llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM);
 }
 
+//
+// Chat template utils
+//
+
 bool llama_chat_verify_template(const std::string & tmpl) {
     llama_chat_message chat[] = {{"user", "test"}};
     int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, nullptr, 0);
     return res >= 0;
 }
 
+std::string llama_chat_apply_template(const struct llama_model * model,
+        const std::string & tmpl,
+        const std::vector<llama_chat_msg> & msgs,
+        bool add_ass) {
+    int alloc_size = 0;
+    std::vector<llama_chat_message> chat;
+    for (auto & msg : msgs) {
+        chat.push_back({msg.role.c_str(), msg.content.c_str()});
+        alloc_size += (msg.role.size() + msg.content.size()) * 1.25;
+    }
+
+    const char * ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str();
+    std::vector<char> buf(alloc_size);
+
+    // run the first time to get the total output length
+    int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
+
+    // if it turns out that our buffer is too small, we resize it
+    if ((size_t) res > buf.size()) {
+        buf.resize(res);
+        res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
+    }
+
+    std::string formatted_chat(buf.data(), res);
+    return formatted_chat;
+}
+
+std::string llama_chat_format_single(const struct llama_model * model,
+        const std::string & tmpl,
+        const std::vector<llama_chat_msg> & past_msg,
+        const llama_chat_msg & new_msg,
+        bool add_ass) {
+    auto fmt_past_msg = llama_chat_apply_template(model, tmpl, past_msg, false);
+    std::vector<llama_chat_msg> chat_new(past_msg);
+    chat_new.push_back(new_msg);
+    auto fmt_new_msg = llama_chat_apply_template(model, tmpl, chat_new, add_ass);
+    auto formatted = fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
+    return formatted;
+}
+
+std::string llama_chat_format_example(const struct llama_model * model,
+        const std::string & tmpl) {
+    std::vector<llama_chat_msg> msgs = {
+        {"system", "You are a helpful assistant"},
+        {"user", "Hello"},
+        {"assistant", "Hi there"},
+        {"user", "How are you?"},
+    };
+    return llama_chat_apply_template(model, tmpl, msgs, true);
+}
+
 //
 // KV cache utils
 //
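The delta trick in llama_chat_format_single (format the history with and without the new message, then keep only the suffix) can be exercised in isolation. A minimal sketch, assuming, as in tests/test-chat-template.cpp, that a nullptr model plus a known template name such as "chatml" is accepted:

// Sketch: format only the newest message against prior history.
#include "common.h"

#include <cstdio>
#include <string>
#include <vector>

int main() {
    std::vector<llama_chat_msg> history = {
        {"system",    "You are a helpful assistant"},
        {"user",      "Hello"},
        {"assistant", "Hi there"},
    };
    llama_chat_msg next{"user", "How are you?"};

    // history is formatted once without and once with the new message; the returned
    // string is the suffix that belongs to 'next' (plus the assistant prefix,
    // because add_ass is true for user turns)
    std::string delta = llama_chat_format_single(nullptr, "chatml", history, next, /*add_ass=*/true);
    printf("%s", delta.c_str());
    return 0;
}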

common/common.h

Lines changed: 23 additions & 0 deletions
@@ -365,9 +365,32 @@ bool llama_should_add_bos_token(const llama_model * model);
 // Chat template utils
 //
 
+// same with llama_chat_message, but uses std::string
+struct llama_chat_msg {
+    std::string role;
+    std::string content;
+};
+
 // Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
 bool llama_chat_verify_template(const std::string & tmpl);
 
+// CPP wrapper for llama_chat_apply_template
+std::string llama_chat_apply_template(const struct llama_model * model,
+        const std::string & tmpl,
+        const std::vector<llama_chat_msg> & chat,
+        bool add_ass);
+
+// Format single message, while taking into account the position of that message in chat history
+std::string llama_chat_format_single(const struct llama_model * model,
+        const std::string & tmpl,
+        const std::vector<llama_chat_msg> & past_msg,
+        const llama_chat_msg & new_msg,
+        bool add_ass);
+
+// Returns an example of formatted chat
+std::string llama_chat_format_example(const struct llama_model * model,
+        const std::string & tmpl);
+
 //
 // KV cache utils
 //
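A minimal sketch of how a caller might drive this wrapper API, assuming a llama_model pointer loaded elsewhere; an empty template string means "use the template stored in the model's metadata":

// Sketch: build an std::string-based conversation and apply the model's chat template.
#include "common.h"

#include <cstdio>
#include <string>
#include <vector>

static void show_prompt(const llama_model * model) {
    std::vector<llama_chat_msg> chat = {
        {"system", "You are a helpful assistant"},
        {"user",   "How are you?"},
    };
    // add_ass=true appends the assistant prefix so generation can continue from it
    std::string prompt = llama_chat_apply_template(model, /*tmpl=*/"", chat, /*add_ass=*/true);
    printf("%s\n", prompt.c_str());

    // the canned example that llama-cli and the server now log at startup
    printf("%s\n", llama_chat_format_example(model, /*tmpl=*/"").c_str());
}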

examples/main/main.cpp

Lines changed: 43 additions & 12 deletions
@@ -39,12 +39,12 @@ static std::ostringstream * g_output_ss;
 static std::vector<llama_token> * g_output_tokens;
 static bool is_interacting = false;
 
-static bool file_exists(const std::string &path) {
+static bool file_exists(const std::string & path) {
     std::ifstream f(path.c_str());
     return f.good();
 }
 
-static bool file_is_empty(const std::string &path) {
+static bool file_is_empty(const std::string & path) {
     std::ifstream f;
     f.exceptions(std::ifstream::failbit | std::ifstream::badbit);
     f.open(path.c_str(), std::ios::in | std::ios::binary | std::ios::ate);
@@ -117,6 +117,14 @@ static void llama_log_callback_logTee(ggml_log_level level, const char * text, v
     LOG_TEE("%s", text);
 }
 
+static std::string chat_add_and_format(struct llama_model * model, std::vector<llama_chat_msg> & chat_msgs, std::string role, std::string content) {
+    llama_chat_msg new_msg{role, content};
+    auto formatted = llama_chat_format_single(
+        model, g_params->chat_template, chat_msgs, new_msg, role == "user");
+    chat_msgs.push_back({role, content});
+    return formatted;
+}
+
 int main(int argc, char ** argv) {
     gpt_params params;
     g_params = &params;
@@ -190,6 +198,7 @@ int main(int argc, char ** argv) {
     llama_model * model;
     llama_context * ctx;
     llama_context * ctx_guidance = NULL;
+    std::vector<llama_chat_msg> chat_msgs;
     g_model = &model;
     g_ctx = &ctx;
 
@@ -215,6 +224,8 @@ int main(int argc, char ** argv) {
                 __func__, n_ctx_train, n_ctx);
     }
 
+    LOG_TEE("%s: chat template example: %s\n", __func__, llama_chat_format_example(model, params.chat_template).c_str());
+
     // print system information
     {
         LOG_TEE("\n");
@@ -249,16 +260,21 @@ int main(int argc, char ** argv) {
 
     std::vector<llama_token> embd_inp;
 
-    if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) {
-        LOG("tokenize the prompt\n");
-        embd_inp = ::llama_tokenize(ctx, params.prompt, true, true);
-    } else {
-        LOG("use session tokens\n");
-        embd_inp = session_tokens;
-    }
+    {
+        auto prompt = params.conversation
+            ? chat_add_and_format(model, chat_msgs, "system", params.prompt) // format the system prompt in conversation mode
+            : params.prompt;
+        if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) {
+            LOG("tokenize the prompt\n");
+            embd_inp = ::llama_tokenize(ctx, prompt, true, true);
+        } else {
+            LOG("use session tokens\n");
+            embd_inp = session_tokens;
+        }
 
-    LOG("prompt: \"%s\"\n", log_tostr(params.prompt));
-    LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
+        LOG("prompt: \"%s\"\n", log_tostr(prompt));
+        LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
+    }
 
     // Should not run without any tokens
     if (embd_inp.empty()) {
@@ -478,6 +494,7 @@ int main(int argc, char ** argv) {
     std::vector<int>   input_tokens;  g_input_tokens  = &input_tokens;
    std::vector<int>   output_tokens; g_output_tokens = &output_tokens;
    std::ostringstream output_ss;     g_output_ss     = &output_ss;
+    std::ostringstream assistant_ss; // for storing current assistant message, used in conversation mode
 
     // the first thing we will do is to output the prompt, so set color accordingly
     console::set_display(console::prompt);
@@ -793,11 +810,18 @@ int main(int argc, char ** argv) {
                     is_antiprompt = true;
                 }
 
+                chat_add_and_format(model, chat_msgs, "system", assistant_ss.str());
                 is_interacting = true;
                 printf("\n");
             }
         }
 
+        // if current token is not EOG, we add it to current assistant message
+        if (params.conversation) {
+            auto id = llama_sampling_last(ctx_sampling);
+            assistant_ss << llama_token_to_piece(ctx, id, false);
+        }
+
         if (n_past > 0 && is_interacting) {
             LOG("waiting for user input\n");
 
@@ -848,8 +872,12 @@ int main(int argc, char ** argv) {
                     string_process_escapes(buffer);
                 }
 
+                std::string user_inp = params.conversation
+                    ? chat_add_and_format(model, chat_msgs, "user", std::move(buffer))
+                    : std::move(buffer);
+                // TODO: one inconvenient of current chat template implementation is that we can't distinguish between user input and special tokens (prefix/postfix)
                 const auto line_pfx = ::llama_tokenize(ctx, params.input_prefix, false, true);
-                const auto line_inp = ::llama_tokenize(ctx, buffer,              false, false);
+                const auto line_inp = ::llama_tokenize(ctx, user_inp,            false, params.conversation);
                 const auto line_sfx = ::llama_tokenize(ctx, params.input_suffix, false, true);
 
                 LOG("input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp).c_str());
@@ -864,6 +892,9 @@ int main(int argc, char ** argv) {
                     output_ss << llama_token_to_piece(ctx, token);
                 }
 
+                // reset assistant message
+                assistant_ss.str("");
+
                 n_remain -= line_inp.size();
                 LOG("n_remain: %d\n", n_remain);
             } else {
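Taken together, conversation mode now builds the prompt incrementally: each turn passes through chat_add_and_format, so only the newly wrapped text is tokenized and fed to the model. A condensed, hypothetical sketch of that flow with the decode/sampling plumbing replaced by prints (nullptr model and the "chatml" template, as in the unit test):

// Sketch of the incremental formatting loop behind llama-cli's conversation mode.
#include "common.h"

#include <iostream>
#include <string>
#include <vector>

static std::string add_and_format(std::vector<llama_chat_msg> & chat_msgs,
                                  const std::string & role, const std::string & content) {
    llama_chat_msg new_msg{role, content};
    // add_ass only for user turns, mirroring chat_add_and_format in main.cpp
    std::string formatted = llama_chat_format_single(nullptr, "chatml", chat_msgs, new_msg, role == "user");
    chat_msgs.push_back(new_msg);
    return formatted;
}

int main() {
    std::vector<llama_chat_msg> chat_msgs;
    std::cout << add_and_format(chat_msgs, "system", "You are a helpful assistant"); // system prompt
    std::cout << add_and_format(chat_msgs, "user", "Hello");                         // would be tokenized and decoded
    std::cout << add_and_format(chat_msgs, "assistant", "Hi there");                 // accumulated from sampled tokens
    std::cout << add_and_format(chat_msgs, "user", "How are you?");
    return 0;
}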

examples/server/server.cpp

Lines changed: 2 additions & 10 deletions
@@ -2606,17 +2606,9 @@ int main(int argc, char ** argv) {
 
     // print sample chat example to make it clear which template is used
     {
-        json chat;
-        chat.push_back({{"role", "system"},    {"content", "You are a helpful assistant"}});
-        chat.push_back({{"role", "user"},      {"content", "Hello"}});
-        chat.push_back({{"role", "assistant"}, {"content", "Hi there"}});
-        chat.push_back({{"role", "user"},      {"content", "How are you?"}});
-
-        const std::string chat_example = format_chat(ctx_server.model, params.chat_template, chat);
-
         LOG_INFO("chat template", {
-            {"chat_example", chat_example},
-            {"built_in", params.chat_template.empty()},
+            {"chat_example", llama_chat_format_example(ctx_server.model, params.chat_template)},
+            {"built_in",     params.chat_template.empty()},
         });
     }
 

examples/server/utils.hpp

Lines changed: 5 additions & 24 deletions
@@ -118,36 +118,17 @@ static inline void server_log(const char * level, const char * function, int lin
 
 // Format given chat. If tmpl is empty, we take the template from model metadata
 inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::vector<json> & messages) {
-    size_t alloc_size = 0;
-    // vector holding all allocated string to be passed to llama_chat_apply_template
-    std::vector<std::string> str(messages.size() * 2);
-    std::vector<llama_chat_message> chat(messages.size());
+    std::vector<llama_chat_msg> chat;
 
     for (size_t i = 0; i < messages.size(); ++i) {
         const auto & curr_msg = messages[i];
-        str[i*2 + 0]    = json_value(curr_msg, "role",    std::string(""));
-        str[i*2 + 1]    = json_value(curr_msg, "content", std::string(""));
-        alloc_size     += str[i*2 + 1].length();
-        chat[i].role    = str[i*2 + 0].c_str();
-        chat[i].content = str[i*2 + 1].c_str();
+        std::string role    = json_value(curr_msg, "role",    std::string(""));
+        std::string content = json_value(curr_msg, "content", std::string(""));
+        chat.push_back({role, content});
     }
 
-    const char * ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str();
-    std::vector<char> buf(alloc_size * 2);
-
-    // run the first time to get the total output length
-    int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), true, buf.data(), buf.size());
-
-    // if it turns out that our buffer is too small, we resize it
-    if ((size_t) res > buf.size()) {
-        buf.resize(res);
-        res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), true, buf.data(), buf.size());
-    }
-
-    const std::string formatted_chat(buf.data(), res);
-
+    auto formatted_chat = llama_chat_apply_template(model, tmpl, chat, true);
     LOG_VERBOSE("formatted_chat", {{"text", formatted_chat.c_str()}});
-
     return formatted_chat;
 }
 
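For reference, a hedged sketch of the JSON shape format_chat consumes (json is the nlohmann::json alias used by the server sources; including the server's utils.hpp standalone is an assumption of this sketch, and the model pointer is loaded elsewhere):

// Sketch: an OpenAI-style message array converted and formatted through the new wrapper.
#include "utils.hpp"   // assumed to bring in format_chat, json and json_value

#include <string>
#include <vector>

static std::string build_prompt(const llama_model * model, const std::string & tmpl) {
    std::vector<json> messages = {
        json{{"role", "system"}, {"content", "You are a helpful assistant"}},
        json{{"role", "user"},   {"content", "How are you?"}},
    };
    // empty tmpl -> template from model metadata; add_ass is always true here
    return format_chat(model, tmpl, messages);
}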

llama.cpp

Lines changed: 2 additions & 2 deletions
@@ -18818,10 +18818,10 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "<|im_start|>assistant\n";
         }
-    } else if (tmpl == "llama2" || tmpl.find("[INST]") != std::string::npos) {
+    } else if (tmpl == "llama2" || tmpl == "mistral" || tmpl.find("[INST]") != std::string::npos) {
         // llama2 template and its variants
         // [variant] support system message
-        bool support_system_message = tmpl.find("<<SYS>>") != std::string::npos;
+        bool support_system_message = tmpl.find("<<SYS>>") != std::string::npos || tmpl == "mistral";
         // [variant] space before + after response
         bool space_around_response = tmpl.find("' ' + eos_token") != std::string::npos;
         // [variant] add BOS inside history

tests/test-chat-template.cpp

Lines changed: 20 additions & 0 deletions
@@ -7,6 +7,7 @@
 #include <cassert>
 
 #include "llama.h"
+#include "common.h"
 
 int main(void) {
     llama_chat_message conversation[] = {
@@ -119,5 +120,24 @@ int main(void) {
         std::cout << output << "\n-------------------------\n";
         assert(output == expected);
     }
+
+    // test llama_chat_format_single
+    std::cout << "\n\n=== llama_chat_format_single ===\n\n";
+    std::vector<llama_chat_msg> chat2;
+    chat2.push_back({"system", "You are a helpful assistant"});
+    chat2.push_back({"user", "Hello"});
+    chat2.push_back({"assistant", "I am assistant"});
+    llama_chat_msg new_msg{"user", "How are you"};
+
+    auto fmt_single = [&](std::string tmpl) {
+        auto output = llama_chat_format_single(nullptr, tmpl, chat2, new_msg, true);
+        std::cout << "fmt_single(" << tmpl << ")\n" << output << "\n-------------------------\n";
+        return output;
+    };
+    assert(fmt_single("chatml") == "<|im_start|>user\nHow are you<|im_end|>\n<|im_start|>assistant\n");
+    assert(fmt_single("llama2") == "[INST] How are you [/INST]");
+    assert(fmt_single("gemma") == "<start_of_turn>user\nHow are you<end_of_turn>\n<start_of_turn>model\n");
+    assert(fmt_single("llama3") == "<|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n");
+
     return 0;
 }
