Skip to content

Python hugging face tokenizer #8354

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Feb 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 5 additions & 7 deletions examples/models/llama/install_requirements.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,12 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Install sentencepiece for llama tokenizer.
# Install tiktoken for tokenizer.
# Install tokenizers for hf .json tokenizer.
# Install snakeviz for cProfile flamegraph
# Install sentencepiece for llama tokenizer
pip install snakeviz sentencepiece

# Install lm-eval for Model Evaluation with lm-evalution-harness
# Install tiktoken for tokenizer
pip install lm_eval==0.4.5
pip install tiktoken blobfile
# Install lm-eval for Model Evaluation with lm-evalution-harness.
pip install tiktoken sentencepiece tokenizers snakeviz lm_eval==0.4.5 blobfile

# Call the install helper for further setup
python examples/models/llama/install_requirement_helper.py
8 changes: 8 additions & 0 deletions examples/models/llama/runner/eager.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def __init__(self, args):
params = json.loads(f.read())
super().__init__(
tokenizer_path=args.tokenizer_path,
tokenizer_config_path=args.tokenizer_config_path,
max_seq_len=args.max_seq_length,
max_batch_size=1,
use_kv_cache=args.use_kv_cache,
Expand Down Expand Up @@ -74,6 +75,13 @@ def build_args_parser() -> argparse.ArgumentParser:
help="Have multi-turn chat with the model",
)

parser.add_argument(
"--tokenizer_config_path",
type=str,
default=None,
help="Path to an accompanying tokenizer_config.json, which provides metadata for the main tokenizer.json",
)

return parser


Expand Down
22 changes: 14 additions & 8 deletions examples/models/llama/runner/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ def next_token(logits: torch.Tensor, temperature: float, top_p: float) -> int:
class LlamaRunner(ABC):
def __init__(
self,
*,
tokenizer_path: str,
tokenizer_config_path: Optional[str] = None,
max_seq_len: int,
max_batch_size: int,
use_kv_cache: bool,
Expand All @@ -59,19 +61,23 @@ def __init__(
Constructor.

Args:
tokenizer_path: path to tokenizer.model file.
max_seq_len: max length of the output sequence, after which the output will be clipped.
max_batch_size: max batch size.
use_kv_cache: whether to use a KV cache.
vocab_size: number of items in the vocab.
device: device to run the runner on.
tokenizer_path: path to tokenizer.model file.
max_seq_len: max length of the output sequence, after which the output will be clipped.
max_batch_size: max batch size.
use_kv_cache: whether to use a KV cache.
vocab_size: number of items in the vocab.
device: device to run the runner on.
"""
self.max_seq_len = max_seq_len
self.max_batch_size = max_batch_size
self.use_kv_cache = use_kv_cache
self.tokenizer = get_tokenizer(tokenizer_path)
self.tokenizer = get_tokenizer(tokenizer_path, tokenizer_config_path)
self.device = device
assert vocab_size == self.tokenizer.n_words
# For some models like qwen, mismatch is acceptable: https://github.com/QwenLM/Qwen2.5/issues/466#issuecomment-2146759706
if vocab_size != self.tokenizer.n_words:
print(
"Warning - given vocab_size in params is unequal to tokenizer vocab size."
)

@abstractmethod
def forward(
Expand Down
17 changes: 17 additions & 0 deletions examples/models/llama/runner/native.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def __init__(self, args):
params = json.loads(f.read())
super().__init__(
tokenizer_path=args.tokenizer,
tokenizer_config_path=args.tokenizer_config,
max_seq_len=args.max_len,
max_batch_size=1,
use_kv_cache=args.kv_cache,
Expand All @@ -56,6 +57,14 @@ def forward(
)[0]


def validate_args(args) -> None:
if args.tokenizer and args.tokenizer.endswith(".json"):
if not args.tokenizer_config:
raise TypeError(
"Json tokenizers require an accompanying tokenizer config (--tokenizer_config) to be specified."
)


def build_args_parser() -> argparse.ArgumentParser:
# TODO: merge these with build_args_parser from export_llama_lib.
parser = argparse.ArgumentParser()
Expand Down Expand Up @@ -85,6 +94,13 @@ def build_args_parser() -> argparse.ArgumentParser:
default=None,
)

parser.add_argument(
"--tokenizer_config",
type=str,
default=None,
help="Path to an accompanying tokenizer_config.json, which provides metadata for the main tokenizer.json",
)

parser.add_argument(
"--prompt",
type=str,
Expand Down Expand Up @@ -116,6 +132,7 @@ def build_args_parser() -> argparse.ArgumentParser:
def main() -> None:
parser = build_args_parser()
args = parser.parse_args()
validate_args(args)
runner = NativeLlamaRunner(args)
generated_tokens = runner.text_completion(
prompt=args.prompt,
Expand Down
13 changes: 7 additions & 6 deletions examples/models/llama3_2_vision/runner/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
class TorchTuneLlamaRunner(LlamaRunner):
def __init__(
self,
*,
tokenizer_path: str,
max_seq_len: int,
max_batch_size: int,
Expand All @@ -21,12 +22,12 @@ def __init__(
device: str = "cpu",
):
super().__init__(
tokenizer_path,
max_seq_len,
max_batch_size,
use_kv_cache,
vocab_size,
device,
tokenizer_path=tokenizer_path,
max_seq_len=max_seq_len,
max_batch_size=max_batch_size,
use_kv_cache=use_kv_cache,
vocab_size=vocab_size,
device=device,
)

self.causal_mask = torch.tril(
Expand Down
56 changes: 56 additions & 0 deletions extension/llm/tokenizer/hf_tokenizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import json
import os
from typing import List, Optional

from tokenizers import Tokenizer


class HuggingFaceTokenizer:
"""
Tokenizing and encoding/decoding text using the Hugging face tokenizer.
"""

def __init__(self, model_path: str, config_path: Optional[str] = None):
"""
Initializes the Tokenizer with a tokenizer.json from HuggingFace.

Args:
model_path (str): The path to the Tiktoken model file.
"""
assert os.path.isfile(model_path), model_path

self.model = tokenizer = Tokenizer.from_file(model_path)

self.n_words: int = tokenizer.get_vocab_size()
if config_path:
with open(config_path) as f:
tokenizer_config = json.load(f)
self.bos_id = (
self.model.token_to_id(tokenizer_config["bos_token"])
if tokenizer_config["bos_token"]
else None
)
self.eos_id = self.model.token_to_id(tokenizer_config["eos_token"])
else: # Fallback guess.
self.bos_id = self.model.token_to_id("<|begin_of_text|>")
self.eos_id = self.model.token_to_id("<|endoftext|>")

self.stop_tokens = [
self.eos_id,
]

def encode(self, s: str, *, bos: bool, eos: bool) -> List[int]:
assert type(s) is str
return self.model.encode(s).ids

def decode(self, t: List[int]) -> str:
return self.model.decode(t)

def decode_token(self, t: int) -> str:
return self.model.decode([t])
19 changes: 13 additions & 6 deletions extension/llm/tokenizer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,23 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

from executorch.examples.models.llama.tokenizer.tiktoken import Tokenizer as Tiktoken
from executorch.extension.llm.tokenizer.tokenizer import (
Tokenizer as SentencePieceTokenizer,
)


def get_tokenizer(tokenizer_path):
try:
tokenizer = SentencePieceTokenizer(model_path=str(tokenizer_path))
except Exception:
print("Using Tiktokenizer")
tokenizer = Tiktoken(model_path=str(tokenizer_path))
def get_tokenizer(tokenizer_path: str, tokenizer_config_path: Optional[str] = None):
if tokenizer_path.endswith(".json"):
from executorch.extension.llm.tokenizer.hf_tokenizer import HuggingFaceTokenizer

tokenizer = HuggingFaceTokenizer(tokenizer_path, tokenizer_config_path)
else:
try:
tokenizer = SentencePieceTokenizer(model_path=str(tokenizer_path))
except Exception:
print("Using Tiktokenizer")
tokenizer = Tiktoken(model_path=str(tokenizer_path))
return tokenizer
Loading