
Commit d624ed3

fix: Remove trailing \n from llama3 <|eot_id|>
There's inconsistency in the documentation on whether or not there should be a \n after <|eot_id|>, but this maintains consistency with previous formatting.

Branch: GraniteCodeSupport
Signed-off-by: Gabe Goodhart <[email protected]>
1 parent 4882899 commit d624ed3
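
In concrete terms, the change is whether the prompt ends in a bare end-of-turn marker or in the marker plus a newline. A minimal sketch of the single-user case (the message text below is a placeholder, not a fixture from the repo):

user_msg = "Hello!"  # placeholder message text, not a value from the repo

# Old expectation: a trailing newline after the end-of-turn marker.
old_prompt = f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{user_msg}<|eot_id|>\n"

# New expectation: the prompt ends at the bare <|eot_id|>.
new_prompt = f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{user_msg}<|eot_id|>"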

File tree

2 files changed (+12, -23 lines)

tests/test_chat_formatters.py

Lines changed: 11 additions & 22 deletions
@@ -139,44 +139,33 @@ def test_llama2_chat_formatter(messages, expected):
         # single user message (no system prompt)
         (MSGS_NO_SYS, f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|>
 
-{USER1}<|eot_id|>
-"""),
+{USER1}<|eot_id|>"""),
         # sys, usr
         (MSGS_SYS_USR, f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
 
-{SYSTEM_PROMPT}<|eot_id|>
-<|start_header_id|>user<|end_header_id|>
+{SYSTEM_PROMPT}<|eot_id|><|start_header_id|>user<|end_header_id|>
 
-{USER1}<|eot_id|>
-"""),
+{USER1}<|eot_id|>"""),
         # sys, usr, asst
         (MSGS_SYS_USR_ASST, f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
 
-{SYSTEM_PROMPT}<|eot_id|>
-<|start_header_id|>user<|end_header_id|>
+{SYSTEM_PROMPT}<|eot_id|><|start_header_id|>user<|end_header_id|>
 
-{USER1}<|eot_id|>
-<|start_header_id|>assistant<|end_header_id|>
+{USER1}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
 
-{ASSISTANT1}<|eot_id|>
-"""),
+{ASSISTANT1}<|eot_id|>"""),
         # sys, usr, asst, usr, asst
         (MSGS_MULTI_TURN, f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
 
-{SYSTEM_PROMPT}<|eot_id|>
-<|start_header_id|>user<|end_header_id|>
+{SYSTEM_PROMPT}<|eot_id|><|start_header_id|>user<|end_header_id|>
 
-{USER1}<|eot_id|>
-<|start_header_id|>assistant<|end_header_id|>
+{USER1}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
 
-{ASSISTANT1}<|eot_id|>
-<|start_header_id|>user<|end_header_id|>
+{ASSISTANT1}<|eot_id|><|start_header_id|>user<|end_header_id|>
 
-{USER2}<|eot_id|>
-<|start_header_id|>assistant<|end_header_id|>
+{USER2}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
 
-{ASSISTANT2}<|eot_id|>
-"""),
+{ASSISTANT2}<|eot_id|>"""),
     ]
 )
 @pytest.mark.parametrize("add_generation_prompt", [True, False])
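
Read as a whole, the updated expectations pin the layout down: each role header is followed by one blank line before the message body, and <|eot_id|> runs directly into the next <|start_header_id|> with no newline between turns. A rough sketch of the sys + usr case, with placeholder values standing in for the test fixtures:

SYSTEM_PROMPT = "You are a helpful assistant."  # placeholder for the test fixture
USER1 = "Hello!"                                # placeholder for the test fixture

expected_sys_usr = (
    "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
    f"{SYSTEM_PROMPT}<|eot_id|>"
    "<|start_header_id|>user<|end_header_id|>\n\n"
    f"{USER1}<|eot_id|>"
)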

torchchat/generate.py

Lines changed: 1 addition & 1 deletion
@@ -121,7 +121,7 @@ def _encode_message(self, message: _ChatFormatter.MESSAGE_TYPE) -> List[int]:
                         self.tokenizer.encode(content["text"], bos=False, eos=False)
                     )
 
-        tokens.append(self.tokenizer.special_tokens["<|eot_id|>\n"])
+        tokens.append(self.tokenizer.special_tokens["<|eot_id|>"])
         return tokens
 
     def encode_dialog_prompt(
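
The key point on the tokenizer side is that the special-token table is keyed by the bare "<|eot_id|>" string, so the lookup must not carry a trailing \n. A minimal sketch of how a per-message token stream is assembled under that assumption (simplified, not the full torchchat _encode_message):

from typing import List

def encode_message_sketch(tokenizer, text: str) -> List[int]:
    # Encode the message body without BOS/EOS, as in the context lines above.
    tokens = tokenizer.encode(text, bos=False, eos=False)
    # Append the end-of-turn marker using the bare key present in special_tokens.
    tokens.append(tokenizer.special_tokens["<|eot_id|>"])
    return tokens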
