Commit 70fd1fe

Qwen runs with HF tokenizer
1 parent 88b3394 commit 70fd1fe

6 files changed: +59 -39 lines changed

examples/models/llama/install_requirements.sh

Lines changed: 5 additions & 7 deletions
@@ -5,14 +5,12 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+# Install sentencepiece for llama tokenizer.
+# Install tiktoken for tokenizer.
+# Install tokenizers for hf .json tokenizer.
 # Install snakeviz for cProfile flamegraph
-# Install sentencepiece for llama tokenizer
-pip install snakeviz sentencepiece
-
-# Install lm-eval for Model Evaluation with lm-evalution-harness
-# Install tiktoken for tokenizer
-pip install lm_eval==0.4.5
-pip install tiktoken blobfile
+# Install lm-eval for Model Evaluation with lm-evalution-harness.
+pip install tiktoken sentencepiece tokenizers snakeviz lm_eval==0.4.5 blobfile

 # Call the install helper for further setup
 python examples/models/llama/install_requirement_helper.py
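
As a quick local sanity check (a sketch, not part of the commit), the tokenizer backends pulled in by the consolidated pip install should all be importable afterwards:

    import sentencepiece
    import tiktoken  # noqa: F401
    import tokenizers

    print("sentencepiece", sentencepiece.__version__)
    print("tokenizers", tokenizers.__version__)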

examples/models/llama/runner/generation.py

Lines changed: 17 additions & 16 deletions
@@ -48,7 +48,9 @@ def next_token(logits: torch.Tensor, temperature: float, top_p: float) -> int:
 class LlamaRunner(ABC):
     def __init__(
         self,
+        *,
         tokenizer_path: str,
+        tokenizer_config_path: Optional[str] = None,
         max_seq_len: int,
         max_batch_size: int,
         use_kv_cache: bool,
@@ -59,20 +61,23 @@ def __init__(
         Constructor.

         Args:
-            tokenizer_path: path to tokenizer.model file.
-            max_seq_len: max length of the output sequence, after which the output will be clipped.
-            max_batch_size: max batch size.
-            use_kv_cache: whether to use a KV cache.
-            vocab_size: number of items in the vocab.
-            device: device to run the runner on.
+            tokenizer_path: path to tokenizer.model file.
+            max_seq_len: max length of the output sequence, after which the output will be clipped.
+            max_batch_size: max batch size.
+            use_kv_cache: whether to use a KV cache.
+            vocab_size: number of items in the vocab.
+            device: device to run the runner on.
         """
         self.max_seq_len = max_seq_len
         self.max_batch_size = max_batch_size
         self.use_kv_cache = use_kv_cache
-        self.tokenizer = get_tokenizer(tokenizer_path)
+        self.tokenizer = get_tokenizer(tokenizer_path, tokenizer_config_path)
         self.device = device
-        # For qwen anything above 151646 is "useless": https://github.com/QwenLM/Qwen2.5/issues/466#issuecomment-2146759706
-        # assert vocab_size == self.tokenizer.n_words
+        # For some models like qwen, mismatch is acceptable: https://github.com/QwenLM/Qwen2.5/issues/466#issuecomment-2146759706
+        if vocab_size != self.tokenizer.n_words:
+            print(
+                "Warning - given vocab_size in params is unequal to tokenizer vocab size."
+            )

     @abstractmethod
     def forward(
@@ -102,8 +107,7 @@ def generate( # noqa: C901
         )

         current_token = next_token(logits, temperature, top_p)
-        # print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True)
-        print(f"{self.tokenizer.decode([current_token])}", end="", flush=True)
+        print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True)
         tokens = prompt_tokens + [current_token]

         while len(tokens) < max_seq_len:
@@ -133,8 +137,7 @@ def generate( # noqa: C901
             ):
                 break

-            # print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True)
-            print(f"{self.tokenizer.decode([current_token])}", end="", flush=True)
+            print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True)
         print("\n")

         return tokens if echo else tokens[len(prompt_tokens) :]
@@ -200,9 +203,7 @@ def chat_completion(
         # prompt_tokens = self.tokenizer.encode(
         #     self._format_prompt(prompt), bos=True, eos=False
         # )
-        prompt_tokens = self.tokenizer.encode(
-            self._format_prompt(prompt)
-        ).ids
+        prompt_tokens = self.tokenizer.encode(self._format_prompt(prompt)).ids
         generated_tokens = self.generate(
             prompt_tokens=pre_stop_token + prompt_tokens,
             max_seq_len=max_seq_len,
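
The constructor is now keyword-only (note the bare *), which is why the TorchTuneLlamaRunner call site below also switches to named arguments. A standalone sketch of the calling convention after this change (RunnerSketch and the paths are hypothetical stand-ins, not the executorch class):

    from typing import Optional


    class RunnerSketch:
        # Mirrors the keyword-only signature of LlamaRunner.__init__ above.
        def __init__(
            self,
            *,
            tokenizer_path: str,
            tokenizer_config_path: Optional[str] = None,
            max_seq_len: int,
            max_batch_size: int,
            use_kv_cache: bool,
            vocab_size: int,
            device: str = "cpu",
        ):
            self.tokenizer_path = tokenizer_path
            self.tokenizer_config_path = tokenizer_config_path


    # Positional calls such as RunnerSketch("tokenizer.json", None, 128, ...) now
    # raise TypeError; every argument must be spelled out by name.
    runner = RunnerSketch(
        tokenizer_path="tokenizer.json",
        tokenizer_config_path="tokenizer_config.json",
        max_seq_len=128,
        max_batch_size=1,
        use_kv_cache=True,
        vocab_size=32000,
    )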

examples/models/llama/runner/native.py

Lines changed: 16 additions & 0 deletions
@@ -37,6 +37,7 @@ def __init__(self, args):
             params = json.loads(f.read())
         super().__init__(
             tokenizer_path=args.tokenizer,
+            tokenizer_config_path=args.tokenizer_config,
             max_seq_len=args.max_len,
             max_batch_size=1,
             use_kv_cache=args.kv_cache,
@@ -56,6 +57,14 @@ def forward(
         )[0]


+def validate_args(args) -> None:
+    if args.tokenizer and args.tokenizer.endswith(".json"):
+        if not args.tokenizer_config:
+            raise TypeError(
+                "Json tokenizers require an accompanying tokenizer config (--tokenizer_config) to be specified."
+            )
+
+
 def build_args_parser() -> argparse.ArgumentParser:
     # TODO: merge these with build_args_parser from export_llama_lib.
     parser = argparse.ArgumentParser()
@@ -85,6 +94,12 @@ def build_args_parser() -> argparse.ArgumentParser:
         default=None,
     )

+    parser.add_argument(
+        "--tokenizer_config",
+        type=str,
+        default=None,
+    )
+
     parser.add_argument(
         "--prompt",
         type=str,
@@ -116,6 +131,7 @@ def build_args_parser() -> argparse.ArgumentParser:
 def main() -> None:
     parser = build_args_parser()
     args = parser.parse_args()
+    validate_args(args)
     runner = NativeLlamaRunner(args)
     generated_tokens = runner.text_completion(
         prompt=args.prompt,
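
A small sketch of the new validation step (validate_args is copied from the diff above; the Namespace objects stand in for parsed CLI arguments): a .json tokenizer passed without --tokenizer_config is rejected before the runner is built.

    from argparse import Namespace


    def validate_args(args) -> None:
        if args.tokenizer and args.tokenizer.endswith(".json"):
            if not args.tokenizer_config:
                raise TypeError(
                    "Json tokenizers require an accompanying tokenizer config (--tokenizer_config) to be specified."
                )


    # A sentencepiece/tiktoken model file needs no config.
    validate_args(Namespace(tokenizer="tokenizer.model", tokenizer_config=None))

    # A Hugging Face tokenizer.json without its config fails fast.
    try:
        validate_args(Namespace(tokenizer="tokenizer.json", tokenizer_config=None))
    except TypeError as err:
        print(err)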

examples/models/llama3_2_vision/runner/generation.py

Lines changed: 7 additions & 6 deletions
@@ -13,6 +13,7 @@
 class TorchTuneLlamaRunner(LlamaRunner):
     def __init__(
         self,
+        *,
         tokenizer_path: str,
         max_seq_len: int,
         max_batch_size: int,
@@ -21,12 +22,12 @@ def __init__(
         device: str = "cpu",
     ):
         super().__init__(
-            tokenizer_path,
-            max_seq_len,
-            max_batch_size,
-            use_kv_cache,
-            vocab_size,
-            device,
+            tokenizer_path=tokenizer_path,
+            max_seq_len=max_seq_len,
+            max_batch_size=max_batch_size,
+            use_kv_cache=use_kv_cache,
+            vocab_size=vocab_size,
+            device=device,
         )

         self.causal_mask = torch.tril(

extension/llm/tokenizer/hf_tokenizer.py

Lines changed: 1 addition & 0 deletions
@@ -3,6 +3,7 @@
 import re
 from typing import Dict, List, Optional

+
 class HFTokenizer:
     def __init__(self):
         self.special_token_encoder: Dict[str, int] = {}

extension/llm/tokenizer/utils.py

Lines changed: 13 additions & 10 deletions
@@ -4,29 +4,32 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+import json
+from typing import Optional
+
 from executorch.examples.models.llama.tokenizer.tiktoken import Tokenizer as Tiktoken
+from executorch.extension.llm.tokenizer.hf_tokenizer import HFTokenizer
 from executorch.extension.llm.tokenizer.tokenizer import (
     Tokenizer as SentencePieceTokenizer,
 )
-from executorch.extension.llm.tokenizer.hf_tokenizer import HFTokenizer


-def get_tokenizer(tokenizer_path):
+def get_tokenizer(tokenizer_path: str, tokenizer_config_path: Optional[str] = None):
     if tokenizer_path.endswith(".json"):
-        # print("Using Hugging Face tokenizer")
-        # tokenizer = HFTokenizer()
-        # tokenizer.load(tokenizer_path)
-
         from tokenizers import Tokenizer

         # Load the tokenizer from the tokenizer.json file
         tokenizer = Tokenizer.from_file(tokenizer_path)
-
-        # from tokenizers import SentencePieceBPETokenizer

-        # tokenizer = SentencePieceBPETokenizer(tokenizer_path)
+        # export_llama expects n_words attribute.
         tokenizer.n_words = tokenizer.get_vocab_size()
-        breakpoint()
+        # Keep in line with internal tokenizer apis.
+        tokenizer.decode_token = lambda token: tokenizer.decode([token])
+
+        if tokenizer_config_path:
+            with open(tokenizer_config_path) as f:
+                tokenizer_config = json.load(f)
+            tokenizer.eos_id = tokenizer.token_to_id(tokenizer_config["eos_token"])
     else:
         try:
             tokenizer = SentencePieceTokenizer(model_path=str(tokenizer_path))
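
For reference, a hedged sketch of what the new .json branch does with the Hugging Face tokenizers package. The file paths are placeholders for a downloaded checkpoint (for example a Qwen tokenizer.json plus its tokenizer_config.json), and it assumes eos_token in the config is a plain string, as the code above does:

    import json

    from tokenizers import Tokenizer

    tokenizer = Tokenizer.from_file("tokenizer.json")

    # Attributes the llama runner expects on its tokenizer object.
    n_words = tokenizer.get_vocab_size()
    decode_token = lambda token: tokenizer.decode([token])  # noqa: E731

    with open("tokenizer_config.json") as f:
        config = json.load(f)
    eos_id = tokenizer.token_to_id(config["eos_token"])

    # encode() returns an Encoding; .ids holds the integer token ids.
    ids = tokenizer.encode("hello world").ids
    print(n_words, eos_id, ids, decode_token(ids[0]))

In the runner itself this path is reached simply as get_tokenizer("tokenizer.json", "tokenizer_config.json"), with the config argument enforced by validate_args in native.py.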
