Skip to content

Migrate extension/llm/tokenizer python users to use the new repo #22

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions pytorch_tokenizers/TARGETS
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Any targets that should be shared between fbcode and xplat must be defined in
# targets.bzl. This file can contain xplat-only targets.

load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
load(":targets.bzl", "define_common_targets")

oncall("executorch")

define_common_targets()

python_library(
name = "hf_tokenizer",
srcs = ["hf_tokenizer.py"],
labels = ["autodeps2_generated"],
deps = [
"fbsource//third-party/pypi/tokenizers:tokenizers",
],
)
27 changes: 27 additions & 0 deletions pytorch_tokenizers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# @lint-ignore-every LICENSELINT


from typing import Optional

from .hf_tokenizer import HuggingFaceTokenizer
from .llama2c import Llama2cTokenizer
from .tiktoken import TiktokenTokenizer

__all__ = ["TiktokenTokenizer", "Llama2cTokenizer", "HuggingFaceTokenizer"]


def get_tokenizer(tokenizer_path: str, tokenizer_config_path: Optional[str] = None):
if tokenizer_path.endswith(".json"):
tokenizer = HuggingFaceTokenizer(tokenizer_path, tokenizer_config_path)
else:
try:
tokenizer = Llama2cTokenizer(model_path=str(tokenizer_path))
except Exception:
print("Using Tiktokenizer")
tokenizer = TiktokenTokenizer(model_path=str(tokenizer_path))
return tokenizer
57 changes: 57 additions & 0 deletions pytorch_tokenizers/hf_tokenizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# @lint-ignore-every LICENSELINT

import json
import os
from typing import List, Optional

from tokenizers import Tokenizer


class HuggingFaceTokenizer:
"""
Tokenizing and encoding/decoding text using the Hugging face tokenizer.
"""

def __init__(self, model_path: str, config_path: Optional[str] = None):
"""
Initializes the Tokenizer with a tokenizer.json from HuggingFace.

Args:
model_path (str): The path to the Tiktoken model file.
"""
assert os.path.isfile(model_path), model_path

self.model = tokenizer = Tokenizer.from_file(model_path)

self.n_words: int = tokenizer.get_vocab_size()
if config_path:
with open(config_path) as f:
tokenizer_config = json.load(f)
self.bos_id = (
self.model.token_to_id(tokenizer_config["bos_token"])
if tokenizer_config["bos_token"]
else None
)
self.eos_id = self.model.token_to_id(tokenizer_config["eos_token"])
else: # Fallback guess.
self.bos_id = self.model.token_to_id("<|begin_of_text|>")
self.eos_id = self.model.token_to_id("<|endoftext|>")

self.stop_tokens = [
self.eos_id,
]

def encode(self, s: str, *, bos: bool, eos: bool) -> List[int]:
assert type(s) is str
return self.model.encode(s).ids

def decode(self, t: List[int]) -> str:
return self.model.decode(t)

def decode_token(self, t: int) -> str:
return self.model.decode([t])
111 changes: 111 additions & 0 deletions pytorch_tokenizers/llama2c.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# @lint-ignore-every LICENSELINT

import logging
import os
import struct
from typing import List

from sentencepiece import SentencePieceProcessor as SentencePieceProcessor


class Llama2cTokenizer:
def __init__(self, model_path: str):
assert os.path.isfile(
model_path
), f"Need a valid tokenizer model path but got {model_path}"
# pyre-fixme[28]: Unexpected keyword argument `model_file` to call `SentencePieceProcessor.__init__`.
self.sp_model = SentencePieceProcessor(model_file=model_path)
self.model_path = model_path

# BOS / EOS token IDs
self.n_words: int = self.sp_model.vocab_size()
self.bos_id: int = self.sp_model.bos_id()
self.eos_id: int = self.sp_model.eos_id()
logging.info(
f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
)
# pyre-fixme[16]: `SentencePieceProcessor` has no attribute `get_piece_size`.
assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()

def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
assert type(s) is str
# pyre-fixme[16]: `SentencePieceProcessor` has no attribute `encode`.
t = self.sp_model.encode(s)
if bos:
t = [self.bos_id] + t
if eos:
t = t + [self.eos_id]
return t

def decode(self, t: List[int]) -> str:
# pyre-fixme[16]: `SentencePieceProcessor` has no attribute `encode`.
return self.sp_model.decode(t)

def decode_token(self, t: int) -> str:
# pyre-fixme[16]: `SentencePieceProcessor` has no attribute `encode`.
return self.sp_model.decode(t)

def export(self, output_path: str, *, prepend_padding: bool = False) -> None:
"""
Export tokenizer.model to another serialization format. Here we did some lightweight
processing such as supporting prepend padding token, prepend max token length and
replace '_' back to empty space.

The binary format is:
1. vocab size: int32
2. bos token id: int32
3. eos token id: int32
4. max token length: int32
5. score: float32, len of bytes: int32, token bytes: [byte] for each token

:param output_path: output path of the new binary.
:param prepend_padding: a boolean to control if we want to prepend a padding token.

:return: None
"""

# get all the tokens (postprocessed) and their scores as floats
tokens, scores = [], []

if prepend_padding:
# Here we use the default padding token and its score.
tokens.append("<pad>".encode("utf-8"))
scores.append(-1)

for i in range(self.n_words):
# decode the token and light postprocessing
# pyre-fixme[16]: `SentencePieceProcessor` has no attribute `id_to_piece`.
t = self.sp_model.id_to_piece(i)
# pyre-fixme[16]: `SentencePieceProcessor` has no attribute `get_score`.
s = self.sp_model.get_score(i)
# sentencepiece use '<s>' as BOS and '</s>' for EOS
if i == self.bos_id:
t = "<s>"
elif i == self.eos_id:
t = "</s>"
t = t.replace("▁", " ") # sentencepiece uses this character as whitespace
b = t.encode("utf-8") # bytes of this token, utf-8 encoded

tokens.append(b)
scores.append(s)

# record the max token length
max_token_length = 0 if not tokens else max(len(t) for t in tokens)

# write to a binary file
with open(output_path, "wb") as f:
# write the vocab size, bos/eos ids and max token length
f.write(
struct.pack(
"IIII", self.n_words, self.bos_id, self.eos_id, max_token_length
)
)
for bytes, score in zip(tokens, scores):
f.write(struct.pack("fI", score, len(bytes)))
f.write(bytes)
logging.info(f"Wrote tokenizer to {output_path}")
34 changes: 34 additions & 0 deletions pytorch_tokenizers/targets.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

def define_common_targets():
"""Defines targets that should be shared between fbcode and xplat.

The directory containing this targets.bzl file should also contain both
TARGETS and BUCK files that call this function.
"""
runtime.python_library(
name = "tokenizers",
srcs = [
"__init__.py",
"llama2c.py",
"tiktoken.py",
"hf_tokenizer.py",
],
base_module = "pytorch_tokenizers",
visibility = [
"//executorch/examples/...",
"//executorch/extension/llm/export/...",
"//bento/...",
"//bento_kernels/...",
"//pytorch/tokenizers/...",
"@EXECUTORCH_CLIENTS",
],
_is_external_target = True,
external_deps = [
"sentencepiece-py",
],
deps = [
"fbsource//third-party/pypi/tiktoken:tiktoken",
"fbsource//third-party/pypi/tokenizers:tokenizers",
],
)
Loading