
Commit 68db765

llama : update llama_chat_apply_template
ggml-ci
1 parent 22b31cd commit 68db765

8 files changed: +31 -34 lines changed
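
In short: the llama_model * argument is removed from llama_chat_apply_template, and a new llama_model_chat_template getter lets callers resolve the model's built-in template themselves. A hedged before/after sketch of a typical call site (variable names are illustrative, not taken from this commit):

    // before this commit: the model was passed so the template could be looked up internally
    int32_t res_old = llama_chat_apply_template(model, /* tmpl */ nullptr, chat, n_msg, true, buf.data(), buf.size());

    // after this commit: the caller resolves the template string explicitly
    const char * tmpl = llama_model_chat_template(model); // nullptr if the model ships no template
    int32_t res_new = llama_chat_apply_template(tmpl, chat, n_msg, true, buf.data(), buf.size());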

common/common.cpp

Lines changed: 8 additions & 9 deletions
@@ -1648,7 +1648,7 @@ std::string common_get_builtin_chat_template(const struct llama_model * model) {
 
 bool common_chat_verify_template(const std::string & tmpl) {
     llama_chat_message chat[] = {{"user", "test"}};
-    int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, nullptr, 0);
+    const int res = llama_chat_apply_template(tmpl.c_str(), chat, 1, true, nullptr, 0);
     return res >= 0;
 }
 
@@ -1659,35 +1659,34 @@ std::string common_chat_apply_template(const struct llama_model * model,
     int alloc_size = 0;
     bool fallback = false; // indicate if we must fallback to default chatml
     std::vector<llama_chat_message> chat;
-    for (auto & msg : msgs) {
+    for (const auto & msg : msgs) {
         chat.push_back({msg.role.c_str(), msg.content.c_str()});
         alloc_size += (msg.role.size() + msg.content.size()) * 1.25;
     }
 
-    const char * ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str();
+    const char * ptr_tmpl = tmpl.empty() ? llama_model_chat_template(model) : tmpl.c_str();
     std::vector<char> buf(alloc_size);
 
     // run the first time to get the total output length
-    int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
+    int32_t res = llama_chat_apply_template(ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
 
     // error: chat template is not supported
     if (res < 0) {
         if (ptr_tmpl != nullptr) {
             // if the custom "tmpl" is not supported, we throw an error
             // this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template()
             throw std::runtime_error("this custom template is not supported");
-        } else {
-            // If the built-in template is not supported, we default to chatml
-            res = llama_chat_apply_template(nullptr, "chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size());
-            fallback = true;
         }
+
+        // If the built-in template is not supported, we default to chatml
+        res = llama_chat_apply_template("chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size());
+        fallback = true;
     }
 
     // if it turns out that our buffer is too small, we resize it
     if ((size_t) res > buf.size()) {
         buf.resize(res);
         res = llama_chat_apply_template(
-            fallback ? nullptr : model,
             fallback ? "chatml" : ptr_tmpl,
             chat.data(), chat.size(), add_ass, buf.data(), buf.size());
     }

examples/run/run.cpp

Lines changed: 2 additions & 2 deletions
@@ -713,11 +713,11 @@ static void add_message(const char * role, const std::string & text, LlamaData &
 // Function to apply the chat template and resize `formatted` if needed
 static int apply_chat_template(LlamaData & llama_data, const bool append) {
     int result = llama_chat_apply_template(
-        llama_data.model.get(), nullptr, llama_data.messages.data(), llama_data.messages.size(), append,
+        llama_model_chat_template(llama_data.model.get()), llama_data.messages.data(), llama_data.messages.size(), append,
         append ? llama_data.fmtted.data() : nullptr, append ? llama_data.fmtted.size() : 0);
     if (append && result > static_cast<int>(llama_data.fmtted.size())) {
         llama_data.fmtted.resize(result);
-        result = llama_chat_apply_template(llama_data.model.get(), nullptr, llama_data.messages.data(),
+        result = llama_chat_apply_template(llama_model_chat_template(llama_data.model.get()), llama_data.messages.data(),
                                            llama_data.messages.size(), append, llama_data.fmtted.data(),
                                            llama_data.fmtted.size());
     }

examples/server/server.cpp

Lines changed: 2 additions & 1 deletion
@@ -1740,7 +1740,8 @@ struct server_context {
 
     bool validate_builtin_chat_template() const {
         llama_chat_message chat[] = {{"user", "test"}};
-        int32_t chat_res = llama_chat_apply_template(model, nullptr, chat, 1, true, nullptr, 0);
+        const char * tmpl = llama_model_chat_template(model);
+        const int32_t chat_res = llama_chat_apply_template(tmpl, chat, 1, true, nullptr, 0);
         return chat_res > 0;
     }
 
examples/simple-chat/simple-chat.cpp

Lines changed: 5 additions & 3 deletions
@@ -161,12 +161,14 @@ int main(int argc, char ** argv) {
             break;
         }
 
+        const char * tmpl = llama_model_chat_template(model);
+
         // add the user input to the message list and format it
        messages.push_back({"user", strdup(user.c_str())});
-        int new_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), true, formatted.data(), formatted.size());
+        int new_len = llama_chat_apply_template(tmpl, messages.data(), messages.size(), true, formatted.data(), formatted.size());
         if (new_len > (int)formatted.size()) {
             formatted.resize(new_len);
-            new_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), true, formatted.data(), formatted.size());
+            new_len = llama_chat_apply_template(tmpl, messages.data(), messages.size(), true, formatted.data(), formatted.size());
         }
         if (new_len < 0) {
             fprintf(stderr, "failed to apply the chat template\n");
@@ -183,7 +185,7 @@ int main(int argc, char ** argv) {
 
         // add the response to the messages
         messages.push_back({"assistant", strdup(response.c_str())});
-        prev_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), false, nullptr, 0);
+        prev_len = llama_chat_apply_template(tmpl, messages.data(), messages.size(), false, nullptr, 0);
         if (prev_len < 0) {
             fprintf(stderr, "failed to apply the chat template\n");
             return 1;

include/llama.h

Lines changed: 3 additions & 2 deletions
@@ -489,6 +489,9 @@ extern "C" {
     // Returns the total size of all the tensors in the model in bytes
     LLAMA_API uint64_t llama_model_size(const struct llama_model * model);
 
+    // Get the default chat template. Returns nullptr if not available
+    LLAMA_API const char * llama_model_chat_template(const struct llama_model * model);
+
     // Returns the total number of parameters in the model
     LLAMA_API uint64_t llama_model_n_params(const struct llama_model * model);
 
@@ -1009,9 +1012,7 @@ extern "C" {
     /// @param buf A buffer to hold the output formatted prompt. The recommended alloc size is 2 * (total number of characters of all messages)
     /// @param length The size of the allocated buffer
     /// @return The total number of bytes of the formatted prompt. If is it larger than the size of buffer, you may need to re-alloc it and then re-apply the template.
-    /// TODO: change to llama_vocab
     LLAMA_API int32_t llama_chat_apply_template(
-            const struct llama_model * model,
             const char * tmpl,
             const struct llama_chat_message * chat,
             size_t n_msg,
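
The doc comment above describes the buffer convention: allocate roughly 2x the total character count of all messages, and if the returned length exceeds the buffer, resize and re-apply. A minimal sketch of that pattern against the new signature (the helper name and the starting buffer size are assumptions, not part of this commit):

    #include <string>
    #include <vector>
    #include "llama.h"

    // sketch: format a chat with the updated API; model is assumed to be loaded elsewhere
    static std::string format_prompt(const struct llama_model * model,
                                     const llama_chat_message * chat, size_t n_msg) {
        // prefer the model's built-in template; a nullptr template falls back to chatml
        const char * tmpl = llama_model_chat_template(model);

        std::vector<char> buf(1024); // first guess; the header suggests 2 * total message characters
        int32_t res = llama_chat_apply_template(tmpl, chat, n_msg, /*add_ass=*/true, buf.data(), buf.size());
        if (res > (int32_t) buf.size()) {
            buf.resize(res); // output did not fit: grow the buffer and re-apply
            res = llama_chat_apply_template(tmpl, chat, n_msg, true, buf.data(), buf.size());
        }
        if (res < 0) {
            return ""; // the template is not supported
        }
        return std::string(buf.data(), res);
    }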

src/llama-model.cpp

Lines changed: 9 additions & 0 deletions
@@ -3839,6 +3839,15 @@ uint64_t llama_model_size(const struct llama_model * model) {
     return model->size();
 }
 
+const char * llama_model_chat_template(const struct llama_model * model) {
+    const auto & it = model->gguf_kv.find("tokenizer.chat_template");
+    if (it == model->gguf_kv.end()) {
+        return nullptr;
+    }
+
+    return it->second.c_str();
+}
+
 uint64_t llama_model_n_params(const struct llama_model * model) {
     return model->n_elements();
 }

src/llama.cpp

Lines changed: 1 addition & 15 deletions
@@ -9834,27 +9834,13 @@ int32_t llama_decode(
 //
 
 int32_t llama_chat_apply_template(
-        const struct llama_model * model,
         const char * tmpl,
         const struct llama_chat_message * chat,
         size_t n_msg,
         bool add_ass,
         char * buf,
         int32_t length) {
-    std::string curr_tmpl(tmpl == nullptr ? "" : tmpl);
-    if (tmpl == nullptr) {
-        GGML_ASSERT(model != nullptr);
-
-        // load template from model, if available
-        const auto & it = model->gguf_kv.find("tokenizer.chat_template");
-        if (it != model->gguf_kv.end() && it->second.size() > 0) {
-            curr_tmpl = it->second;
-        }
-        else {
-            // worst case: there is no information about template, we will use chatml by default
-            curr_tmpl = "chatml"; // see llm_chat_apply_template
-        }
-    }
+    const std::string curr_tmpl(tmpl == nullptr ? "chatml" : tmpl);
 
     // format the chat to string
     std::vector<const llama_chat_message *> chat_vec;
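
A consequence worth noting: with the model lookup removed here, a nullptr template now simply selects chatml instead of reading tokenizer.chat_template from the model. A tiny illustrative snippet (assumes llama.h and <vector> are included):

    // passing nullptr now always formats with chatml; to use the model's own
    // template, fetch it first via llama_model_chat_template()
    llama_chat_message msg[] = {{"user", "test"}};
    std::vector<char> out(256);
    const int32_t n = llama_chat_apply_template(nullptr, msg, 1, true, out.data(), out.size());
    // n >= 0: length of the chatml-formatted prompt; n < 0: error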

tests/test-chat-template.cpp

Lines changed: 1 addition & 2 deletions
@@ -157,15 +157,14 @@ int main(void) {
     }
 
     // test invalid chat template
-    res = llama_chat_apply_template(nullptr, "INVALID TEMPLATE", conversation, message_count, true, formatted_chat.data(), formatted_chat.size());
+    res = llama_chat_apply_template("INVALID TEMPLATE", conversation, message_count, true, formatted_chat.data(), formatted_chat.size());
     assert(res < 0);
 
     for (size_t i = 0; i < templates.size(); i++) {
         std::string custom_template = templates[i];
         std::string expected = expected_output[i];
         formatted_chat.resize(1024);
         res = llama_chat_apply_template(
-            nullptr,
             custom_template.c_str(),
             conversation,
             message_count,
