Commit 46cf1c7

larryliu0820 authored and facebook-github-bot committed
Add Tiktoken in python (#2986)
Summary: Tiktoken by OpenAI is a popular tokenizer.

Pull Request resolved: #2986

Reviewed By: lucylq

Differential Revision: D56004355

Pulled By: larryliu0820

fbshipit-source-id: 5656eba6fc6e550fc1d7356162da1d1897e43e78
1 parent 76d8513 commit 46cf1c7

Showing 2 changed files with 235 additions and 1 deletion.


examples/models/llama2/install_requirements.sh

Lines changed: 2 additions & 1 deletion
@@ -11,7 +11,8 @@ pip install snakeviz sentencepiece
 pip install torchao==0.1

 # Install lm-eval for Model Evaluation with lm-evaluation-harness
-pip install lm-eval
+# Install tiktoken for tokenizer
+pip install lm-eval tiktoken blobfile

 # Call the install helper for further setup
 python examples/models/llama2/install_requirement_helper.py
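
As a quick sanity check that the new dependencies resolve, the following editorial sketch (not part of the commit) exercises tiktoken directly; "cl100k_base" is a stock OpenAI encoding used purely for illustration, not the Llama tokenizer defined below.

# Verify the tiktoken install added above; fetches a public encoding file
# on first use.
import tiktoken

enc = tiktoken.get_encoding("cl100k_base")  # stock encoding, illustration only
ids = enc.encode("hello world")
assert enc.decode(ids) == "hello world"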
Lines changed: 233 additions & 0 deletions
@@ -0,0 +1,233 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
from logging import getLogger
from pathlib import Path
from typing import (
    AbstractSet,
    cast,
    Collection,
    Dict,
    Iterator,
    List,
    Literal,
    Sequence,
    TypedDict,
    Union,
)

import tiktoken
from tiktoken.load import load_tiktoken_bpe


logger = getLogger(__name__)


Role = Literal["system", "user", "assistant"]


class Message(TypedDict):
    role: Role
    content: str


Dialog = Sequence[Message]
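
# Editorial note (not in the original commit): a Dialog is just a sequence of
# role/content records, e.g. [{"role": "user", "content": "Hello"}].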

class Tokenizer:
    """
    Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
    """

    special_tokens: Dict[str, int]

    num_reserved_special_tokens = 256

    pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"  # noqa: E501

    def __init__(self, model_path: str):
        """
        Initializes the Tokenizer with a Tiktoken model.

        Args:
            model_path (str): The path to the Tiktoken model file.
        """
        # reload tokenizer
        assert os.path.isfile(model_path), model_path

        mergeable_ranks = load_tiktoken_bpe(model_path)
        num_base_tokens = len(mergeable_ranks)
        special_tokens = [
            "<|begin_of_text|>",
            "<|end_of_text|>",
            "<|reserved_special_token_0|>",
            "<|reserved_special_token_1|>",
            "<|reserved_special_token_2|>",
            "<|reserved_special_token_3|>",
            "<|start_header_id|>",
            "<|end_header_id|>",
            "<|reserved_special_token_4|>",
            "<|eot_id|>",  # end of turn
        ] + [
            f"<|reserved_special_token_{i}|>"
            for i in range(5, self.num_reserved_special_tokens - 5)
        ]
        self.special_tokens = {
            token: num_base_tokens + i for i, token in enumerate(special_tokens)
        }
        self.model = tiktoken.Encoding(
            name=Path(model_path).name,
            pat_str=self.pat_str,
            mergeable_ranks=mergeable_ranks,
            special_tokens=self.special_tokens,
        )
        logger.info(f"Reloaded Tiktoken model from {model_path}")

        # BOS / EOS token IDs
        self.n_words: int = self.model.n_vocab
        self.bos_id: int = self.special_tokens["<|begin_of_text|>"]
        self.eos_id: int = self.special_tokens["<|end_of_text|>"]
        self.pad_id: int = -1
        self.stop_tokens = {
            self.special_tokens["<|end_of_text|>"],
            self.special_tokens["<|eot_id|>"],
        }
        logger.info(
            f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
        )
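
    # Editorial note (not in the original commit): base BPE merges occupy
    # token ids [0, num_base_tokens); the special tokens above are appended
    # after them, so n_vocab == num_base_tokens + len(special_tokens).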

    def encode(
        self,
        s: str,
        *,
        bos: bool,
        eos: bool,
        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),  # noqa B006
        disallowed_special: Union[Literal["all"], Collection[str]] = (),
    ) -> List[int]:
        """
        Encodes a string into a list of token IDs.

        Args:
            s (str): The input string to be encoded.
            bos (bool): Whether to prepend the beginning-of-sequence token.
            eos (bool): Whether to append the end-of-sequence token.
            allowed_special ("all"|set[str]): special tokens allowed in the string
            disallowed_special ("all"|set[str]): special tokens that raise an error when found in the string

        Returns:
            list[int]: A list of token IDs.

        By default, setting disallowed_special=() encodes a string by ignoring
        special tokens. Specifically:
        - Setting `disallowed_special` to () will cause all text corresponding
          to special tokens to be encoded as natural text (instead of raising
          an error).
        - Setting `allowed_special` to "all" will cause all text corresponding
          to special tokens to be encoded as special tokens.
        """
        assert type(s) is str

        # The tiktoken tokenizer can handle <=400k chars without
        # pyo3_runtime.PanicException (may go beyond 400k)
        TIKTOKEN_MAX_ENCODE_CHARS = 400_000

        # https://github.com/openai/tiktoken/issues/195
        # Here we iterate over subsequences and split if we exceed the limit
        # of max consecutive non-whitespace or whitespace characters.
        MAX_NO_WHITESPACES_CHARS = 25_000

        substrs = (
            substr
            for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS)
            for substr in self._split_whitespaces_or_nonwhitespaces(
                s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
            )
        )
        t: List[int] = []
        for substr in substrs:
            t.extend(
                self.model.encode(
                    substr,
                    allowed_special=allowed_special,
                    disallowed_special=disallowed_special,
                )
            )
        if bos:
            t.insert(0, self.bos_id)
        if eos:
            t.append(self.eos_id)
        return t
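
    # Illustrative note (not in the original commit): with allowed_special="all",
    # a substring like "<|eot_id|>" encodes to its single special-token id;
    # with the defaults above it is BPE-encoded as ordinary text instead.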

    def decode(self, t: Sequence[int]) -> str:
        """
        Decodes a list of token IDs into a string.

        Args:
            t (List[int]): The list of token IDs to be decoded.

        Returns:
            str: The decoded string.
        """
        # typecast is safe here, Tiktoken doesn't do anything list-related with the sequence.
        return self.model.decode(cast(List[int], t))

    @staticmethod
    def _split_whitespaces_or_nonwhitespaces(
        s: str, max_consecutive_slice_len: int
    ) -> Iterator[str]:
        """
        Splits the string `s` so that each substring contains no more than
        `max_consecutive_slice_len` consecutive whitespace or consecutive
        non-whitespace characters.
        """
        current_slice_len = 0
        current_slice_is_space = s[0].isspace() if len(s) > 0 else False
        slice_start = 0

        for i in range(len(s)):
            is_now_space = s[i].isspace()

            if current_slice_is_space ^ is_now_space:
                current_slice_len = 1
                current_slice_is_space = is_now_space
            else:
                current_slice_len += 1
                if current_slice_len > max_consecutive_slice_len:
                    yield s[slice_start:i]
                    slice_start = i
                    current_slice_len = 1
        yield s[slice_start:]
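
# Illustrative example (not in the original commit): with
# max_consecutive_slice_len=3, _split_whitespaces_or_nonwhitespaces("aaaaa  bb", 3)
# yields "aaa" then "aa  bb": a run longer than the limit is cut, and the
# counter resets whenever the text flips between whitespace and non-whitespace.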


class ChatFormat:
    def __init__(self, tokenizer: Tokenizer):
        self.tokenizer = tokenizer

    def encode_header(self, message: Message) -> List[int]:
        tokens = []
        tokens.append(self.tokenizer.special_tokens["<|start_header_id|>"])
        tokens.extend(self.tokenizer.encode(message["role"], bos=False, eos=False))
        tokens.append(self.tokenizer.special_tokens["<|end_header_id|>"])
        tokens.extend(self.tokenizer.encode("\n\n", bos=False, eos=False))
        return tokens

    def encode_message(self, message: Message) -> List[int]:
        tokens = self.encode_header(message)
        tokens.extend(
            self.tokenizer.encode(message["content"].strip(), bos=False, eos=False)
        )
        tokens.append(self.tokenizer.special_tokens["<|eot_id|>"])
        return tokens

    def encode_dialog_prompt(self, dialog: Dialog) -> List[int]:
        tokens = []
        tokens.append(self.tokenizer.special_tokens["<|begin_of_text|>"])
        for message in dialog:
            tokens.extend(self.encode_message(message))
        # Add the start of an assistant message for the model to complete
        tokens.extend(self.encode_header({"role": "assistant", "content": ""}))
        return tokens
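
For reference, a minimal usage sketch (an editorial addition, not part of the commit): it assumes a Llama 3-style tiktoken BPE file at the hypothetical path "tokenizer.model", and that Tokenizer and ChatFormat are importable from wherever this file lands in the tree.

# Usage sketch -- the module name "tiktoken_tokenizer" and the model path are
# assumptions; adjust both to the actual file locations.
from tiktoken_tokenizer import ChatFormat, Tokenizer

tok = Tokenizer("tokenizer.model")

# Round-trip plain text (no BOS/EOS, so decode returns the input verbatim).
ids = tok.encode("Hello world", bos=False, eos=False)
assert tok.decode(ids) == "Hello world"

# Build a Llama 3-style chat prompt; encode_dialog_prompt deliberately ends
# with an open assistant header so generation continues as the assistant.
chat = ChatFormat(tok)
prompt_tokens = chat.encode_dialog_prompt(
    [{"role": "user", "content": "What is ExecuTorch?"}]
)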
