Skip to content

Commit 11d4705

Browse files
MollySophia and ngxson authored
Rwkv chat template fix (#10001)
* llama: remove useless template matching for rwkv-world

  Signed-off-by: Molly Sophia <[email protected]>

* converter: Add comment about the hack for rwkv models

  Signed-off-by: Molly Sophia <[email protected]>

* Update src/llama.cpp

  Co-authored-by: Xuan Son Nguyen <[email protected]>

---------

Signed-off-by: Molly Sophia <[email protected]>
Co-authored-by: Xuan Son Nguyen <[email protected]>
1 parent c421ac0 commit 11d4705

File tree

3 files changed

+3
-5
lines changed

3 files changed

+3
-5
lines changed

convert_hf_to_gguf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2865,6 +2865,7 @@ def set_vocab(self):
28652865
self.gguf_writer.add_token_types(toktypes)
28662866
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False)
28672867
special_vocab.chat_template = "rwkv-world"
2868+
# hack: Add '\n\n' as the EOT token to make it chat normally
28682869
special_vocab._set_special_token("eot", 261)
28692870
special_vocab.add_to_gguf(self.gguf_writer)
28702871

src/llama.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21697,7 +21697,8 @@ static int32_t llama_chat_apply_template_internal(
2169721697
if (add_ass) {
2169821698
ss << "[|assistant|]";
2169921699
}
21700-
} else if (tmpl == "rwkv-world" || tmpl_contains("rwkv-world") || tmpl_contains("'User: ' + message['content'] + '\n\nAssistant:'")) {
21700+
} else if (tmpl == "rwkv-world" || tmpl_contains("rwkv-world")) {
21701+
// this template requires the model to have "\n\n" as EOT token
2170121702
for (auto message : chat) {
2170221703
std::string role(message->role);
2170321704
if (role == "user") {

tests/test-chat-template.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,6 @@ int main(void) {
6565
u8"{% for message in messages %}{% if message['role'] == 'user' %}{{'<用户>' + message['content'].strip() + '<AI>'}}{% else %}{{message['content'].strip()}}{% endif %}{% endfor %}",
6666
// DeepSeek-V2
6767
"{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ 'User: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}",
68-
// RWKV-World
69-
"{% for message in messages %}{% if message['role'] == 'user' %}{{'User: ' + message['content'] + '\n\nAssistant:'}}{% else %}{{message['content'] + '\n\n'}}{% endif %}{% endfor %}",
7068
};
7169
std::vector<std::string> expected_output = {
7270
// teknium/OpenHermes-2.5-Mistral-7B
@@ -111,8 +109,6 @@ int main(void) {
111109
u8"You are a helpful assistant<用户>Hello<AI>Hi there<用户>Who are you<AI>I am an assistant<用户>Another question<AI>",
112110
// DeepSeek-V2
113111
u8"You are a helpful assistant\n\nUser: Hello\n\nAssistant: Hi there<|end▁of▁sentence|>User: Who are you\n\nAssistant: I am an assistant <|end▁of▁sentence|>User: Another question\n\nAssistant:",
114-
// RWKV-World
115-
"You are a helpful assistant\n\nUser: Hello\n\nAssistant:Hi there\n\nUser: Who are you\n\nAssistant: I am an assistant \n\nUser: Another question\n\nAssistant:",
116112
};
117113
std::vector<char> formatted_chat(1024);
118114
int32_t res;

0 commit comments

Comments (0)