Commit c463b47

fix(chat): Fix small formatting bugs in llama3 chat formatter

Branch: GraniteCodeSupport
Signed-off-by: Gabe Goodhart <[email protected]>
1 parent 678d828 commit c463b47

1 file changed: +3 −3 lines changed

torchchat/generate.py

Lines changed: 3 additions & 3 deletions
@@ -110,7 +110,7 @@ def _encode_message(self, message: _ChatFormatter.MESSAGE_TYPE) -> List[int]:
                         self.tokenizer.encode(content["text"], bos=False, eos=False)
                     )

-        tokens.append(self.tokenizer.special_tokens["<|eot_id|>"])
+        tokens.append(self.tokenizer.special_tokens["<|eot_id|>\n"])
         return tokens

     def encode_dialog_prompt(
@@ -123,8 +123,8 @@ def encode_dialog_prompt(
         for message in dialog:
             tokens.extend(self._encode_message(message))
         # Add the start of an assistant message for the model to complete.
-        if add_generation_prompt:
-            tokens.extend(self._encode_header("assistant"))  # Pass role directly as a string
+        if add_generation_prompt and dialog and dialog[-1]["role"] != "assistant":
+            tokens.extend(self._encode_header("assistant"))  # Pass role directly as a string
         return tokens


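For context, a minimal string-level sketch of the behavior the second hunk guards: an assistant turn is only opened when add_generation_prompt is set and the dialog does not already end with an assistant message. The helper names and string formatting below are hypothetical stand-ins, not torchchat's token-level _encode_header/_encode_message implementation.

# Hypothetical sketch, not torchchat code: plain strings stand in for token IDs.
from typing import Dict, List


def encode_header(role: str) -> str:
    # Llama 3 style message header.
    return f"<|start_header_id|>{role}<|end_header_id|>\n\n"


def encode_message(message: Dict[str, str]) -> str:
    # Header, content, then the end-of-turn marker.
    return encode_header(message["role"]) + message["content"] + "<|eot_id|>"


def encode_dialog_prompt(
    dialog: List[Dict[str, str]], add_generation_prompt: bool = True
) -> str:
    out = "<|begin_of_text|>" + "".join(encode_message(m) for m in dialog)
    # The guard from the commit: only open a new assistant turn when one is
    # requested AND the dialog does not already end with an assistant message.
    if add_generation_prompt and dialog and dialog[-1]["role"] != "assistant":
        out += encode_header("assistant")
    return out


if __name__ == "__main__":
    dialog = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello!"},
    ]
    print(encode_dialog_prompt(dialog))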