Commit 0f21a6a

ngxson authored and Nexesenex committed
Add chatml fallback for cpp llama_chat_apply_template (ggml-org#8160)
* add chatml fallback for cpp `llama_chat_apply_template`
* remove redundant code
1 parent 8e4ec1c commit 0f21a6a

File tree

common/common.cpp
common/common.h

2 files changed: +20 -1 lines changed

common/common.cpp

Lines changed: 18 additions & 1 deletion
@@ -2643,6 +2643,7 @@ std::string llama_chat_apply_template(const struct llama_model * model,
         const std::vector<llama_chat_msg> & msgs,
         bool add_ass) {
     int alloc_size = 0;
+    bool fallback = false; // indicate if we must fallback to default chatml
     std::vector<llama_chat_message> chat;
     for (auto & msg : msgs) {
         chat.push_back({msg.role.c_str(), msg.content.c_str()});
@@ -2655,10 +2656,26 @@ std::string llama_chat_apply_template(const struct llama_model * model,
     // run the first time to get the total output length
     int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());

+    // error: chat template is not supported
+    if (res < 0) {
+        if (ptr_tmpl != nullptr) {
+            // if the custom "tmpl" is not supported, we throw an error
+            // this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template()
+            throw std::runtime_error("this custom template is not supported");
+        } else {
+            // If the built-in template is not supported, we default to chatml
+            res = llama_chat_apply_template(nullptr, "chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size());
+            fallback = true;
+        }
+    }
+
     // if it turns out that our buffer is too small, we resize it
     if ((size_t) res > buf.size()) {
         buf.resize(res);
-        res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
+        res = llama_chat_apply_template(
+            fallback ? nullptr : model,
+            fallback ? "chatml" : ptr_tmpl,
+            chat.data(), chat.size(), add_ass, buf.data(), buf.size());
     }

     std::string formatted_chat(buf.data(), res);
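
For orientation, below is a caller-side sketch of how this C++ wrapper is typically used; it is illustrative and not part of this commit. It assumes, as in the surrounding wrapper code, that an empty tmpl selects the template embedded in the model, which after this change falls back to chatml when that built-in template is not recognized.

#include "common.h"

// sketch: format a short conversation with the model's built-in chat
// template; an unrecognized built-in template now falls back to chatml
// instead of producing an error
std::string format_example(const struct llama_model * model) {
    std::vector<llama_chat_msg> msgs = {
        {"system", "You are a helpful assistant."},
        {"user",   "Hello!"},
    };
    // empty tmpl -> use the model's built-in template;
    // add_ass = true appends the assistant prefix so the result is ready for generation
    return llama_chat_apply_template(model, "", msgs, /* add_ass */ true);
}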

common/common.h

Lines changed: 2 additions & 0 deletions
@@ -397,6 +397,8 @@ struct llama_chat_msg {
 bool llama_chat_verify_template(const std::string & tmpl);

 // CPP wrapper for llama_chat_apply_template
+// If the built-in template is not supported, we default to chatml
+// If the custom "tmpl" is not supported, we throw an error
 std::string llama_chat_apply_template(const struct llama_model * model,
         const std::string & tmpl,
         const std::vector<llama_chat_msg> & chat,
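
As a usage note, here is a hedged sketch of how a caller might handle the two documented outcomes: a non-empty custom template that is not supported throws, while an empty template uses the model's built-in one with the new chatml fallback. The helper name and error handling are assumptions for illustration, not part of this commit.

#include "common.h"

#include <cstdio>
#include <stdexcept>

// sketch: prefer a user-supplied template, recover if it is rejected
std::string apply_or_fallback(const struct llama_model * model,
                              const std::string & custom_tmpl,
                              const std::vector<llama_chat_msg> & msgs) {
    try {
        // non-empty custom_tmpl: throws std::runtime_error if unsupported
        return llama_chat_apply_template(model, custom_tmpl, msgs, true);
    } catch (const std::runtime_error & err) {
        fprintf(stderr, "custom template rejected (%s); using built-in/chatml\n", err.what());
        // empty tmpl: model's built-in template, falling back to chatml
        return llama_chat_apply_template(model, "", msgs, true);
    }
}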
