5
5
# LICENSE file in the root directory of this source tree.
6
6
7
7
# Standard
8
- from typing import List , Optional
8
+ from typing import Dict , List , Optional
9
9
import json
10
10
import os
11
11
12
12
# Third Party
13
+ import jinja2
13
14
from tokenizers import Tokenizer
14
15
15
16
# Local
@@ -37,6 +38,9 @@ def __init__(self, file_path: str):
37
38
# Load the tokenizer itself
38
39
self ._tokenizer = Tokenizer .from_file (tokenizer_path )
39
40
41
+ # Load the chat template if we have a config path
42
+ self ._chat_template : Optional [jinja2 .Template ] = None
43
+
40
44
# If available, parse bos/eos tokens from the tokenizer config
41
45
self ._bos_id , self ._eos_id = None , None
42
46
if tokenizer_config_path is not None :
@@ -48,6 +52,8 @@ def __init__(self, file_path: str):
48
52
self ._bos_id = self ._tokenizer .token_to_id (bos_token )
49
53
if eos_token is not None :
50
54
self ._eos_id = self ._tokenizer .token_to_id (eos_token )
55
+ if chat_template_str := tok_config .get ("chat_template" ):
56
+ self ._chat_template = jinja2 .Template (chat_template_str )
51
57
52
58
# If no eos/bos tokens found, go looking for them!
53
59
if None in [self ._bos_id , self ._eos_id ]:
@@ -70,6 +76,8 @@ def _look_for_special_token(added_tokens: dict, search_strs: List[str]) -> Optio
70
76
if len (candidate_toks ) == 1 :
71
77
return candidate_toks [0 ]["id" ]
72
78
79
+ ## Interface ##
80
+
73
81
def encode (
74
82
self ,
75
83
s : str ,
@@ -90,3 +98,21 @@ def bos_id(self) -> int:
90
98
91
99
def eos_id(self) -> int:
    """Return the end-of-sequence token id stored on this tokenizer.

    NOTE(review): the constructor initialises the id to ``None`` before
    trying to resolve it, so this may presumably still be ``None`` when no
    EOS token could be found -- confirm against the lookup logic.
    """
    eos = self._eos_id
    return eos
101
+
102
+ ## Additional Public Methods ##
103
+
104
def has_chat_template(self) -> bool:
    """Report whether a chat template is available.

    The answer is the truthiness of ``self._chat_template``: ``False`` while
    it is ``None`` (its initial value), ``True`` once a template object has
    been stored.
    """
    if self._chat_template:
        return True
    return False
106
+
107
def apply_chat_template(
    self,
    dialog: List[Dict[str, str]],
    add_generation_prompt: bool = False,
) -> str:
    """Render a list of chat messages through the configured chat template.

    Args:
        dialog: The messages to render; handed to the template under the
            name ``messages``.
        add_generation_prompt: Forwarded unchanged to the template under
            the same name.

    Returns:
        The result of rendering the template with those two variables.

    Raises:
        ValueError: If no chat template is configured.
    """
    template = self._chat_template
    if template:
        rendered = template.render(
            messages=dialog,
            add_generation_prompt=add_generation_prompt,
        )
        return rendered
    raise ValueError("No chat template configured!")
0 commit comments