
Commit 9ba5494

Python Hugging Face tokenizer (#8354)
1 parent d729176 commit 9ba5494

File tree

examples/models/llama/install_requirements.sh
examples/models/llama/runner/eager.py
examples/models/llama/runner/generation.py
examples/models/llama/runner/native.py
examples/models/llama3_2_vision/runner/generation.py
extension/llm/tokenizer/hf_tokenizer.py
extension/llm/tokenizer/utils.py

7 files changed: +120 −27 lines changed

examples/models/llama/install_requirements.sh

Lines changed: 5 additions & 7 deletions
@@ -5,14 +5,12 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+# Install sentencepiece for llama tokenizer.
+# Install tiktoken for tokenizer.
+# Install tokenizers for hf .json tokenizer.
 # Install snakeviz for cProfile flamegraph
-# Install sentencepiece for llama tokenizer
-pip install snakeviz sentencepiece
-
-# Install lm-eval for Model Evaluation with lm-evalution-harness
-# Install tiktoken for tokenizer
-pip install lm_eval==0.4.5
-pip install tiktoken blobfile
+# Install lm-eval for Model Evaluation with lm-evaluation-harness.
+pip install tiktoken sentencepiece tokenizers snakeviz lm_eval==0.4.5 blobfile

 # Call the install helper for further setup
 python examples/models/llama/install_requirement_helper.py
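
As a quick sanity check (not part of the commit), the tokenizer backends installed above should all import cleanly; a minimal sketch:

# Hypothetical sanity check: confirm the dependencies installed above resolve.
import importlib

for pkg in ("sentencepiece", "tiktoken", "tokenizers", "lm_eval", "blobfile", "snakeviz"):
    importlib.import_module(pkg)
print("All tokenizer-related dependencies import cleanly.")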

examples/models/llama/runner/eager.py

Lines changed: 8 additions & 0 deletions
@@ -28,6 +28,7 @@ def __init__(self, args):
             params = json.loads(f.read())
         super().__init__(
             tokenizer_path=args.tokenizer_path,
+            tokenizer_config_path=args.tokenizer_config_path,
             max_seq_len=args.max_seq_length,
             max_batch_size=1,
             use_kv_cache=args.use_kv_cache,
@@ -74,6 +75,13 @@ def build_args_parser() -> argparse.ArgumentParser:
         help="Have multi-turn chat with the model",
     )

+    parser.add_argument(
+        "--tokenizer_config_path",
+        type=str,
+        default=None,
+        help="Path to an accompanying tokenizer_config.json, which provides metadata for the main tokenizer.json",
+    )
+
     return parser
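A minimal sketch (with hypothetical argv values) of how the new flag is parsed and then handed to the runner:

# Sketch: parsing the new --tokenizer_config_path flag, as in build_args_parser
# above. The argv values are placeholder paths, not files shipped with the repo.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--tokenizer_path", type=str)
parser.add_argument(
    "--tokenizer_config_path",
    type=str,
    default=None,
    help="Path to an accompanying tokenizer_config.json, which provides metadata for the main tokenizer.json",
)
args = parser.parse_args(
    ["--tokenizer_path", "tokenizer.json", "--tokenizer_config_path", "tokenizer_config.json"]
)
# Both values are then forwarded to LlamaRunner.__init__ as keyword arguments.
print(args.tokenizer_path, args.tokenizer_config_path)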

examples/models/llama/runner/generation.py

Lines changed: 14 additions & 8 deletions
@@ -48,7 +48,9 @@ def next_token(logits: torch.Tensor, temperature: float, top_p: float) -> int:
 class LlamaRunner(ABC):
     def __init__(
         self,
+        *,
         tokenizer_path: str,
+        tokenizer_config_path: Optional[str] = None,
         max_seq_len: int,
         max_batch_size: int,
         use_kv_cache: bool,
@@ -59,19 +61,23 @@ def __init__(
         Constructor.

         Args:
-        tokenizer_path: path to tokenizer.model file.
-        max_seq_len: max length of the output sequence, after which the output will be clipped.
-        max_batch_size: max batch size.
-        use_kv_cache: whether to use a KV cache.
-        vocab_size: number of items in the vocab.
-        device: device to run the runner on.
+            tokenizer_path: path to tokenizer.model file.
+            max_seq_len: max length of the output sequence, after which the output will be clipped.
+            max_batch_size: max batch size.
+            use_kv_cache: whether to use a KV cache.
+            vocab_size: number of items in the vocab.
+            device: device to run the runner on.
         """
         self.max_seq_len = max_seq_len
         self.max_batch_size = max_batch_size
         self.use_kv_cache = use_kv_cache
-        self.tokenizer = get_tokenizer(tokenizer_path)
+        self.tokenizer = get_tokenizer(tokenizer_path, tokenizer_config_path)
         self.device = device
-        assert vocab_size == self.tokenizer.n_words
+        # For some models like qwen, mismatch is acceptable: https://github.com/QwenLM/Qwen2.5/issues/466#issuecomment-2146759706
+        if vocab_size != self.tokenizer.n_words:
+            print(
+                "Warning - given vocab_size in params is unequal to tokenizer vocab size."
+            )

     @abstractmethod
     def forward(
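
The bare `*` makes every constructor parameter keyword-only, which is why the TorchTuneLlamaRunner subclass below is updated to forward arguments by name. A standalone sketch of the pattern (not the real LlamaRunner signature):

# Standalone sketch of the keyword-only constructor pattern; hypothetical class.
from typing import Optional

class Runner:
    def __init__(
        self,
        *,  # everything after the bare * must be passed by keyword
        tokenizer_path: str,
        tokenizer_config_path: Optional[str] = None,
        max_seq_len: int = 128,
    ):
        self.tokenizer_path = tokenizer_path
        self.tokenizer_config_path = tokenizer_config_path
        self.max_seq_len = max_seq_len

Runner(tokenizer_path="tokenizer.json", tokenizer_config_path="tokenizer_config.json")

try:
    Runner("tokenizer.json")  # a positional call now fails loudly
except TypeError as e:
    print(e)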

examples/models/llama/runner/native.py

Lines changed: 17 additions & 0 deletions
@@ -37,6 +37,7 @@ def __init__(self, args):
             params = json.loads(f.read())
         super().__init__(
             tokenizer_path=args.tokenizer,
+            tokenizer_config_path=args.tokenizer_config,
             max_seq_len=args.max_len,
             max_batch_size=1,
             use_kv_cache=args.kv_cache,
@@ -56,6 +57,14 @@ def forward(
         )[0]


+def validate_args(args) -> None:
+    if args.tokenizer and args.tokenizer.endswith(".json"):
+        if not args.tokenizer_config:
+            raise TypeError(
+                "Json tokenizers require an accompanying tokenizer config (--tokenizer_config) to be specified."
+            )
+
+
 def build_args_parser() -> argparse.ArgumentParser:
     # TODO: merge these with build_args_parser from export_llama_lib.
     parser = argparse.ArgumentParser()
@@ -85,6 +94,13 @@ def build_args_parser() -> argparse.ArgumentParser:
         default=None,
     )

+    parser.add_argument(
+        "--tokenizer_config",
+        type=str,
+        default=None,
+        help="Path to an accompanying tokenizer_config.json, which provides metadata for the main tokenizer.json",
+    )
+
     parser.add_argument(
         "--prompt",
         type=str,
@@ -116,6 +132,7 @@ def build_args_parser() -> argparse.ArgumentParser:
 def main() -> None:
     parser = build_args_parser()
     args = parser.parse_args()
+    validate_args(args)
     runner = NativeLlamaRunner(args)
     generated_tokens = runner.text_completion(
         prompt=args.prompt,
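
The rule enforced by `validate_args` can be exercised in isolation; a sketch using stand-in argument objects in place of real argparse output:

# Sketch: the .json-requires-config rule, quoted from the diff above and
# driven with SimpleNamespace stand-ins for parsed args.
from types import SimpleNamespace

def validate_args(args) -> None:
    if args.tokenizer and args.tokenizer.endswith(".json"):
        if not args.tokenizer_config:
            raise TypeError(
                "Json tokenizers require an accompanying tokenizer config (--tokenizer_config) to be specified."
            )

# A SentencePiece/Tiktoken model file needs no config.
validate_args(SimpleNamespace(tokenizer="tokenizer.model", tokenizer_config=None))

# A HuggingFace tokenizer.json without a config is rejected.
try:
    validate_args(SimpleNamespace(tokenizer="tokenizer.json", tokenizer_config=None))
except TypeError as e:
    print(e)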

examples/models/llama3_2_vision/runner/generation.py

Lines changed: 7 additions & 6 deletions
@@ -13,6 +13,7 @@
 class TorchTuneLlamaRunner(LlamaRunner):
     def __init__(
         self,
+        *,
         tokenizer_path: str,
         max_seq_len: int,
         max_batch_size: int,
@@ -21,12 +22,12 @@ def __init__(
         device: str = "cpu",
     ):
         super().__init__(
-            tokenizer_path,
-            max_seq_len,
-            max_batch_size,
-            use_kv_cache,
-            vocab_size,
-            device,
+            tokenizer_path=tokenizer_path,
+            max_seq_len=max_seq_len,
+            max_batch_size=max_batch_size,
+            use_kv_cache=use_kv_cache,
+            vocab_size=vocab_size,
+            device=device,
         )

         self.causal_mask = torch.tril(
extension/llm/tokenizer/hf_tokenizer.py

Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import json
+import os
+from typing import List, Optional
+
+from tokenizers import Tokenizer
+
+
+class HuggingFaceTokenizer:
+    """
+    Tokenizing and encoding/decoding text using the Hugging Face tokenizer.
+    """
+
+    def __init__(self, model_path: str, config_path: Optional[str] = None):
+        """
+        Initializes the Tokenizer with a tokenizer.json from HuggingFace.
+
+        Args:
+            model_path (str): The path to the tokenizer.json file.
+        """
+        assert os.path.isfile(model_path), model_path
+
+        self.model = tokenizer = Tokenizer.from_file(model_path)
+
+        self.n_words: int = tokenizer.get_vocab_size()
+        if config_path:
+            with open(config_path) as f:
+                tokenizer_config = json.load(f)
+            self.bos_id = (
+                self.model.token_to_id(tokenizer_config["bos_token"])
+                if tokenizer_config["bos_token"]
+                else None
+            )
+            self.eos_id = self.model.token_to_id(tokenizer_config["eos_token"])
+        else:  # Fallback guess.
+            self.bos_id = self.model.token_to_id("<|begin_of_text|>")
+            self.eos_id = self.model.token_to_id("<|endoftext|>")
+
+        self.stop_tokens = [
+            self.eos_id,
+        ]
+
+    def encode(self, s: str, *, bos: bool, eos: bool) -> List[int]:
+        assert type(s) is str
+        return self.model.encode(s).ids
+
+    def decode(self, t: List[int]) -> str:
+        return self.model.decode(t)
+
+    def decode_token(self, t: int) -> str:
+        return self.model.decode([t])
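
A hedged usage sketch of the new class. The paths are placeholders for a tokenizer.json and tokenizer_config.json downloaded from a HuggingFace model repo; note that encode's bos/eos flags are accepted for interface compatibility with the other tokenizers but are not applied by this implementation.

# Usage sketch; "tokenizer.json" / "tokenizer_config.json" are placeholder
# paths, not files shipped with the repo.
from executorch.extension.llm.tokenizer.hf_tokenizer import HuggingFaceTokenizer

tok = HuggingFaceTokenizer("tokenizer.json", config_path="tokenizer_config.json")
ids = tok.encode("Hello, world!", bos=False, eos=False)
print(tok.n_words, tok.bos_id, tok.eos_id)
print(tok.decode(ids))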

extension/llm/tokenizer/utils.py

Lines changed: 13 additions & 6 deletions
@@ -4,16 +4,23 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+from typing import Optional
+
 from executorch.examples.models.llama.tokenizer.tiktoken import Tokenizer as Tiktoken
 from executorch.extension.llm.tokenizer.tokenizer import (
     Tokenizer as SentencePieceTokenizer,
 )


-def get_tokenizer(tokenizer_path):
-    try:
-        tokenizer = SentencePieceTokenizer(model_path=str(tokenizer_path))
-    except Exception:
-        print("Using Tiktokenizer")
-        tokenizer = Tiktoken(model_path=str(tokenizer_path))
+def get_tokenizer(tokenizer_path: str, tokenizer_config_path: Optional[str] = None):
+    if tokenizer_path.endswith(".json"):
+        from executorch.extension.llm.tokenizer.hf_tokenizer import HuggingFaceTokenizer
+
+        tokenizer = HuggingFaceTokenizer(tokenizer_path, tokenizer_config_path)
+    else:
+        try:
+            tokenizer = SentencePieceTokenizer(model_path=str(tokenizer_path))
+        except Exception:
+            print("Using Tiktokenizer")
+            tokenizer = Tiktoken(model_path=str(tokenizer_path))
     return tokenizer
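
Dispatch is decided purely by file extension; a sketch of the resulting behavior (placeholder paths):

# Sketch of the get_tokenizer dispatch rule; paths are placeholders.
from executorch.extension.llm.tokenizer.utils import get_tokenizer

# .json -> HuggingFaceTokenizer, optionally with its tokenizer_config.json.
hf_tok = get_tokenizer("tokenizer.json", "tokenizer_config.json")

# Anything else -> SentencePiece first, falling back to Tiktoken on failure.
sp_or_tt_tok = get_tokenizer("tokenizer.model")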
