Commit 99112c5

ngxson authored and Nexesenex committed

Fix new line issue with chat template, disable template when in-prefix/suffix is set (ggml-org#8203)

* preserve new line llama_chat_format_single
* disable chat template if in-prefix/suffix is set
* remove redundant change
1 parent ec1ce34 commit 99112c5

File tree

4 files changed: +23 -9 lines changed


common/common.cpp

Lines changed: 13 additions & 3 deletions

@@ -1039,16 +1039,19 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
     }
     if (arg == "--in-prefix-bos") {
         params.input_prefix_bos = true;
+        params.enable_chat_template = false;
         return true;
     }
     if (arg == "--in-prefix") {
         CHECK_ARG
         params.input_prefix = argv[i];
+        params.enable_chat_template = false;
         return true;
     }
     if (arg == "--in-suffix") {
         CHECK_ARG
         params.input_suffix = argv[i];
+        params.enable_chat_template = false;
         return true;
     }
     if (arg == "--spm-infill") {
@@ -1431,7 +1434,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
         "halt generation at PROMPT, return control in interactive mode\n"
         "can be specified more than once for multiple prompts" });
     options.push_back({ "main", "-sp, --special", "special tokens output enabled (default: %s)", params.special ? "true" : "false" });
-    options.push_back({ "main", "-cnv, --conversation", "run in conversation mode (does not print special tokens and suffix/prefix) (default: %s)", params.conversation ? "true" : "false" });
+    options.push_back({ "main", "-cnv, --conversation", "run in conversation mode (does not print special tokens and suffix/prefix, use default chat template) (default: %s)", params.conversation ? "true" : "false" });
     options.push_back({ "main infill", "-i, --interactive", "run in interactive mode (default: %s)", params.interactive ? "true" : "false" });
     options.push_back({ "main infill", "-if, --interactive-first", "run in interactive mode and wait for input right away (default: %s)", params.interactive_first ? "true" : "false" });
     options.push_back({ "main infill", "-mli, --multiline-input", "allows you to write or paste multiple lines without ending each in '\\'" });
@@ -2693,12 +2696,19 @@ std::string llama_chat_format_single(const struct llama_model * model,
         const std::vector<llama_chat_msg> & past_msg,
         const llama_chat_msg & new_msg,
         bool add_ass) {
+    std::ostringstream ss;
     auto fmt_past_msg = llama_chat_apply_template(model, tmpl, past_msg, false);
     std::vector<llama_chat_msg> chat_new(past_msg);
+    // if the past_msg ends with a newline, we must preserve it in the formatted version
+    if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
+        ss << "\n";
+    };
+    // format chat with new_msg
     chat_new.push_back(new_msg);
     auto fmt_new_msg = llama_chat_apply_template(model, tmpl, chat_new, add_ass);
-    auto formatted = fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
-    return formatted;
+    // get the diff part
+    ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
+    return ss.str();
 }
 
 std::string llama_chat_format_example(const struct llama_model * model,
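Note on the helper being patched: llama_chat_format_single renders the conversation twice, once without and once with the new message, and returns only the suffix that the new message adds. Because substr() cuts at the length of the formatted history, any newline the template places after the last finished message stays on the history side of the cut; that is the newline this commit re-emits at the front of the returned chunk. The following is a self-contained sketch of the idea, not repository code; toy_apply_template is a hypothetical stand-in for llama_chat_apply_template.

    #include <cassert>
    #include <sstream>
    #include <string>
    #include <utility>
    #include <vector>

    // Hypothetical stand-in for llama_chat_apply_template: renders each message as
    // "<role>: <content>\n" and, when requested, opens an assistant turn.
    static std::string toy_apply_template(const std::vector<std::pair<std::string, std::string>> & msgs, bool add_ass) {
        std::ostringstream out;
        for (const auto & m : msgs) {
            out << m.first << ": " << m.second << "\n";
        }
        if (add_ass) {
            out << "assistant:";
        }
        return out.str();
    }

    // Same diff-the-two-renderings idea as llama_chat_format_single: the chunk for
    // the new message is whatever the full rendering adds beyond the history, and a
    // trailing newline of the history is re-emitted because substr() would leave it
    // on the history side of the cut.
    static std::string toy_format_single(std::vector<std::pair<std::string, std::string>> past,
                                         const std::pair<std::string, std::string> & new_msg,
                                         bool add_ass) {
        std::ostringstream ss;
        const std::string fmt_past = toy_apply_template(past, false);
        if (add_ass && !fmt_past.empty() && fmt_past.back() == '\n') {
            ss << "\n";
        }
        past.push_back(new_msg);
        const std::string fmt_new = toy_apply_template(past, add_ass);
        ss << fmt_new.substr(fmt_past.size());
        return ss.str();
    }

    int main() {
        std::vector<std::pair<std::string, std::string>> history = {{"user", "Hi"}, {"assistant", "Hello"}};
        // The returned chunk starts with the newline that closed the last history message.
        assert(toy_format_single(history, {"user", "How are you"}, true) == "\nuser: How are you\nassistant:");
        return 0;
    }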

common/common.h

Lines changed: 1 addition & 0 deletions

@@ -217,6 +217,7 @@ struct gpt_params {
     std::string public_path   = "";
     std::string chat_template = "";
     std::string system_prompt = "";
+    bool enable_chat_template = true;
 
     std::vector<std::string> api_keys;
 
examples/main/main.cpp

Lines changed: 7 additions & 4 deletions

@@ -262,7 +262,7 @@ int main(int argc, char ** argv) {
     std::vector<llama_token> embd_inp;
 
     {
-        auto prompt = params.conversation
+        auto prompt = (params.conversation && params.enable_chat_template)
             ? chat_add_and_format(model, chat_msgs, "system", params.prompt) // format the system prompt in conversation mode
             : params.prompt;
         if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) {
@@ -811,7 +811,9 @@
                     is_antiprompt = true;
                 }
 
-                chat_add_and_format(model, chat_msgs, "assistant", assistant_ss.str());
+                if (params.enable_chat_template) {
+                    chat_add_and_format(model, chat_msgs, "assistant", assistant_ss.str());
+                }
                 is_interacting = true;
                 printf("\n");
             }
@@ -873,12 +875,13 @@
                     string_process_escapes(buffer);
                 }
 
-                std::string user_inp = params.conversation
+                bool format_chat = params.conversation && params.enable_chat_template;
+                std::string user_inp = format_chat
                     ? chat_add_and_format(model, chat_msgs, "user", std::move(buffer))
                     : std::move(buffer);
                 // TODO: one inconvenient of current chat template implementation is that we can't distinguish between user input and special tokens (prefix/postfix)
                 const auto line_pfx = ::llama_tokenize(ctx, params.input_prefix, false, true);
-                const auto line_inp = ::llama_tokenize(ctx, user_inp, false, params.conversation);
+                const auto line_inp = ::llama_tokenize(ctx, user_inp, false, format_chat);
                 const auto line_sfx = ::llama_tokenize(ctx, params.input_suffix, false, true);
 
                 LOG("input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp).c_str());

tests/test-chat-template.cpp

Lines changed: 2 additions & 2 deletions

@@ -142,9 +142,9 @@ int main(void) {
         std::cout << "fmt_single(" << tmpl << ")\n" << output << "\n-------------------------\n";
         return output;
     };
-    assert(fmt_single("chatml") == "<|im_start|>user\nHow are you<|im_end|>\n<|im_start|>assistant\n");
+    assert(fmt_single("chatml") == "\n<|im_start|>user\nHow are you<|im_end|>\n<|im_start|>assistant\n");
     assert(fmt_single("llama2") == "[INST] How are you [/INST]");
-    assert(fmt_single("gemma") == "<start_of_turn>user\nHow are you<end_of_turn>\n<start_of_turn>model\n");
+    assert(fmt_single("gemma") == "\n<start_of_turn>user\nHow are you<end_of_turn>\n<start_of_turn>model\n");
     assert(fmt_single("llama3") == "<|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n");
 
     return 0;
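The updated expectations follow directly from the newline preservation: templates whose formatted history ends with a newline (ChatML and Gemma in this test) now get that newline back at the start of the single-message chunk, while templates that do not (llama2 and llama3 here) keep their previous output. A small sketch of the condition, using illustrative strings rather than the exact fixtures from the test:

    #include <cassert>
    #include <string>

    // The condition added by the patch: prepend "\n" only when an assistant turn is
    // being opened (add_ass) and the already-formatted history ends with a newline.
    static bool needs_leading_newline(const std::string & fmt_past_msg, bool add_ass) {
        return add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n';
    }

    int main() {
        // A history rendering that ends in '\n' (as ChatML/Gemma renderings do) gains a leading newline.
        assert(needs_leading_newline("<|im_start|>system\nYou are a helpful assistant<|im_end|>\n", true));
        // A rendering with no trailing newline is returned unchanged.
        assert(!needs_leading_newline("[INST] You are a helpful assistant [/INST]", true));
        return 0;
    }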
