Skip to content

Commit 15a275d

Browse files
committed
fix(chat): Add HFTokenizerChatFormatter and use it for HF tokenizers
This will allow the jinja2 templates for HF tokenizers to be applied without needing to hard-code the formatter logic. This will likely need to be duplicated in the embedded code version of chat. Branch: GraniteCodeSupport Signed-off-by: Gabe Goodhart <[email protected]>
1 parent 978f598 commit 15a275d

File tree

1 file changed

+13
-0
lines changed

1 file changed

+13
-0
lines changed

torchchat/generate.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,15 @@ def encode_dialog_prompt(self, dialog) -> List[int]:
125125
return tokens
126126

127127

128+
class HFTokenizerChatFormatter(_ChatFormatter):
    """Chat formatter that delegates prompt construction to the HF
    tokenizer's built-in (jinja2) chat template instead of hard-coding
    the formatting logic for each model family.
    """

    def encode_dialog_prompt(self, dialog) -> List[int]:
        """Render *dialog* through the tokenizer's chat template (with the
        generation prompt appended) and tokenize the rendered output.

        NOTE(review): assumes ``apply_chat_template`` returns rendered text
        suitable for ``encode`` — confirm against the tokenizer wrapper.
        """
        template_output = self.tokenizer.apply_chat_template(
            dialog, add_generation_prompt=True
        )
        return self.tokenizer.encode(template_output)
135+
136+
128137
@dataclass
129138
class GeneratorArgs:
130139
prompt: Optional[str] = (
@@ -286,6 +295,10 @@ def __init__(
286295
logging.debug(
287296
"Llama3 model detected in chat mode. Using updated sentence schemas"
288297
)
298+
elif self.tokenizer_args.is_hf_tokenizer:
299+
if not self.tokenizer.has_chat_template():
300+
raise ValueError("Tokenizer must have a chat template")
301+
self.chat_formatter = HFTokenizerChatFormatter(self.tokenizer)
289302
else:
290303
self.chat_formatter = Llama2ChatFormatter(self.tokenizer)
291304

0 commit comments

Comments
 (0)