Skip to content

Commit 4dfe6a2

Browse files
committed
feat(tokenizer): Add basic support for jinja2 template rendering for HF tokenizers
This is a much-simplified version of the corresponding logic in transformers. I opted for this so that the full transformers dependency is not added here.

CITE: https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L1522
Branch: GraniteCodeSupport
Signed-off-by: Gabe Goodhart <[email protected]>
1 parent 715a58b commit 4dfe6a2

File tree

1 file changed

+27
-1
lines changed

1 file changed

+27
-1
lines changed

tokenizer/hf_tokenizer.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,12 @@
55
# LICENSE file in the root directory of this source tree.
66

77
# Standard
8-
from typing import List, Optional
8+
from typing import Dict, List, Optional
99
import json
1010
import os
1111

1212
# Third Party
13+
import jinja2
1314
from tokenizers import Tokenizer
1415

1516
# Local
@@ -37,6 +38,9 @@ def __init__(self, file_path: str):
3738
# Load the tokenizer itself
3839
self._tokenizer = Tokenizer.from_file(tokenizer_path)
3940

41+
# Load the chat template if we have a config path
42+
self._chat_template: Optional[jinja2.Template] = None
43+
4044
# If available, parse bos/eos tokens from the tokenizer config
4145
self._bos_id, self._eos_id = None, None
4246
if tokenizer_config_path is not None:
@@ -48,6 +52,8 @@ def __init__(self, file_path: str):
4852
self._bos_id = self._tokenizer.token_to_id(bos_token)
4953
if eos_token is not None:
5054
self._eos_id = self._tokenizer.token_to_id(eos_token)
55+
if chat_template_str := tok_config.get("chat_template"):
56+
self._chat_template = jinja2.Template(chat_template_str)
5157

5258
# If no eos/bos tokens found, go looking for them!
5359
if None in [self._bos_id, self._eos_id]:
@@ -70,6 +76,8 @@ def _look_for_special_token(added_tokens: dict, search_strs: List[str]) -> Optio
7076
if len(candidate_toks) == 1:
7177
return candidate_toks[0]["id"]
7278

79+
## Interface ##
80+
7381
def encode(
7482
self,
7583
s: str,
@@ -90,3 +98,21 @@ def bos_id(self) -> int:
9098

9199
def eos_id(self) -> int:
92100
return self._eos_id
101+
102+
## Additional Public Methods ##
103+
104+
def has_chat_template(self) -> bool:
    """Report whether a chat template is set on this tokenizer.

    Returns:
        The truthiness of ``self._chat_template``.
    """
    return True if self._chat_template else False
106+
107+
def apply_chat_template(
    self,
    dialog: List[Dict[str, str]],
    add_generation_prompt: bool = False,
) -> str:
    """Render the configured chat template over a list of messages.

    Args:
        dialog: The messages; handed to the template as ``messages``.
        add_generation_prompt: Handed to the template under the same name.

    Returns:
        Whatever the template's ``render`` call produces.

    Raises:
        ValueError: If no chat template is configured.
    """
    template = self._chat_template
    if template:
        return template.render(
            messages=dialog,
            add_generation_prompt=add_generation_prompt,
        )
    raise ValueError("No chat template configured!")

0 commit comments

Comments
 (0)