
Commit 0763945

Migrate extension/llm/tokenizer python users to use the new repo
Differential Revision: D69820450
Pull Request resolved: #22
1 parent b3ba207 commit 0763945

File tree

8 files changed: +486 -119 lines


pytorch_tokenizers/TARGETS

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
# Any targets that should be shared between fbcode and xplat must be defined in
# targets.bzl. This file can contain xplat-only targets.

load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
load(":targets.bzl", "define_common_targets")

oncall("executorch")

define_common_targets()

python_library(
    name = "hf_tokenizer",
    srcs = ["hf_tokenizer.py"],
    labels = ["autodeps2_generated"],
    deps = [
        "fbsource//third-party/pypi/tokenizers:tokenizers",
    ],
)

pytorch_tokenizers/__init__.py

Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# @lint-ignore-every LICENSELINT


from typing import Optional

from .hf_tokenizer import HuggingFaceTokenizer
from .llama2c import Llama2cTokenizer
from .tiktoken import TiktokenTokenizer

__all__ = ["TiktokenTokenizer", "Llama2cTokenizer", "HuggingFaceTokenizer"]


def get_tokenizer(tokenizer_path: str, tokenizer_config_path: Optional[str] = None):
    if tokenizer_path.endswith(".json"):
        tokenizer = HuggingFaceTokenizer(tokenizer_path, tokenizer_config_path)
    else:
        try:
            tokenizer = Llama2cTokenizer(model_path=str(tokenizer_path))
        except Exception:
            print("Using TiktokenTokenizer")
            tokenizer = TiktokenTokenizer(model_path=str(tokenizer_path))
    return tokenizer
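
A minimal usage sketch of this dispatch logic follows; it is not part of the commit, and the file paths are hypothetical placeholders:

from pytorch_tokenizers import get_tokenizer

# A .json path routes to HuggingFaceTokenizer; the optional second argument
# points at a tokenizer_config.json naming the bos/eos tokens.
hf_tok = get_tokenizer("tokenizer.json", "tokenizer_config.json")

# Any other path first tries Llama2cTokenizer (sentencepiece), falling back
# to TiktokenTokenizer if loading raises.
sp_tok = get_tokenizer("tokenizer.model")

ids = sp_tok.encode("hello world", bos=True, eos=False)
print(sp_tok.decode(ids))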

pytorch_tokenizers/hf_tokenizer.py

Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# @lint-ignore-every LICENSELINT

import json
import os
from typing import List, Optional

from tokenizers import Tokenizer


class HuggingFaceTokenizer:
    """
    Tokenizing and encoding/decoding text using the Hugging Face tokenizers library.
    """

    def __init__(self, model_path: str, config_path: Optional[str] = None):
        """
        Initializes the Tokenizer with a tokenizer.json from HuggingFace.

        Args:
            model_path (str): The path to the tokenizer.json file.
            config_path (Optional[str]): Optional path to a tokenizer_config.json
                that names the bos/eos tokens.
        """
        assert os.path.isfile(model_path), model_path

        self.model = Tokenizer.from_file(model_path)

        self.n_words: int = self.model.get_vocab_size()
        if config_path:
            with open(config_path) as f:
                tokenizer_config = json.load(f)
            self.bos_id = (
                self.model.token_to_id(tokenizer_config["bos_token"])
                if tokenizer_config["bos_token"]
                else None
            )
            self.eos_id = self.model.token_to_id(tokenizer_config["eos_token"])
        else:  # Fallback guess.
            self.bos_id = self.model.token_to_id("<|begin_of_text|>")
            self.eos_id = self.model.token_to_id("<|endoftext|>")

        self.stop_tokens = [
            self.eos_id,
        ]

    def encode(self, s: str, *, bos: bool, eos: bool) -> List[int]:
        # Note: bos/eos are accepted for API parity with the other tokenizers
        # but are not applied here; the underlying Tokenizer's configured
        # post-processing determines any special tokens.
        assert type(s) is str
        return self.model.encode(s).ids

    def decode(self, t: List[int]) -> str:
        return self.model.decode(t)

    def decode_token(self, t: int) -> str:
        return self.model.decode([t])
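
For reference, a short sketch of using the class directly, assuming a tokenizer.json exported from Hugging Face (the path is a placeholder):

from pytorch_tokenizers import HuggingFaceTokenizer

tok = HuggingFaceTokenizer("tokenizer.json")  # placeholder path
print(tok.n_words)  # vocabulary size reported by the underlying Tokenizer
ids = tok.encode("hello world", bos=False, eos=False)
print(tok.decode(ids))
print(tok.decode_token(ids[0]))  # decode a single token id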

pytorch_tokenizers/llama2c.py

Lines changed: 111 additions & 0 deletions
@@ -0,0 +1,111 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# @lint-ignore-every LICENSELINT

import logging
import os
import struct
from typing import List

from sentencepiece import SentencePieceProcessor


class Llama2cTokenizer:
    def __init__(self, model_path: str):
        assert os.path.isfile(
            model_path
        ), f"Need a valid tokenizer model path but got {model_path}"
        # pyre-fixme[28]: Unexpected keyword argument `model_file` to call `SentencePieceProcessor.__init__`.
        self.sp_model = SentencePieceProcessor(model_file=model_path)
        self.model_path = model_path

        # BOS / EOS token IDs
        self.n_words: int = self.sp_model.vocab_size()
        self.bos_id: int = self.sp_model.bos_id()
        self.eos_id: int = self.sp_model.eos_id()
        logging.info(
            f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
        )
        # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `get_piece_size`.
        assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()

    def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
        assert type(s) is str
        # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `encode`.
        t = self.sp_model.encode(s)
        if bos:
            t = [self.bos_id] + t
        if eos:
            t = t + [self.eos_id]
        return t

    def decode(self, t: List[int]) -> str:
        # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `decode`.
        return self.sp_model.decode(t)

    def decode_token(self, t: int) -> str:
        # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `decode`.
        return self.sp_model.decode(t)

    def export(self, output_path: str, *, prepend_padding: bool = False) -> None:
        """
        Export tokenizer.model to another serialization format. This does some
        lightweight processing, such as optionally prepending a padding token,
        writing the max token length up front, and replacing '▁' with a space.

        The binary format is:
        1. vocab size: int32
        2. bos token id: int32
        3. eos token id: int32
        4. max token length: int32
        5. score: float32, len of bytes: int32, token bytes: [byte] for each token

        :param output_path: output path of the new binary.
        :param prepend_padding: controls whether a padding token is prepended.

        :return: None
        """

        # Get all the tokens (postprocessed) and their scores as floats.
        tokens, scores = [], []

        if prepend_padding:
            # Here we use the default padding token and its score.
            tokens.append("<pad>".encode("utf-8"))
            scores.append(-1)

        for i in range(self.n_words):
            # Decode the token and do light postprocessing.
            # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `id_to_piece`.
            t = self.sp_model.id_to_piece(i)
            # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `get_score`.
            s = self.sp_model.get_score(i)
            # sentencepiece uses '<s>' for BOS and '</s>' for EOS
            if i == self.bos_id:
                t = "<s>"
            elif i == self.eos_id:
                t = "</s>"
            t = t.replace("▁", " ")  # sentencepiece uses this character as whitespace
            b = t.encode("utf-8")  # bytes of this token, utf-8 encoded

            tokens.append(b)
            scores.append(s)

        # Record the max token length.
        max_token_length = 0 if not tokens else max(len(t) for t in tokens)

        # Write to a binary file.
        with open(output_path, "wb") as f:
            # Write the vocab size, bos/eos ids, and max token length.
            f.write(
                struct.pack(
                    "IIII", self.n_words, self.bos_id, self.eos_id, max_token_length
                )
            )
            for token_bytes, score in zip(tokens, scores):
                f.write(struct.pack("fI", score, len(token_bytes)))
                f.write(token_bytes)
        logging.info(f"Wrote tokenizer to {output_path}")

pytorch_tokenizers/targets.bzl

Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

def define_common_targets():
    """Defines targets that should be shared between fbcode and xplat.

    The directory containing this targets.bzl file should also contain both
    TARGETS and BUCK files that call this function.
    """
    runtime.python_library(
        name = "tokenizers",
        srcs = [
            "__init__.py",
            "llama2c.py",
            "tiktoken.py",
            "hf_tokenizer.py",
        ],
        base_module = "pytorch_tokenizers",
        visibility = [
            "//executorch/examples/...",
            "//executorch/extension/llm/export/...",
            "//bento/...",
            "//bento_kernels/...",
            "//pytorch/tokenizers/...",
            "@EXECUTORCH_CLIENTS",
        ],
        _is_external_target = True,
        external_deps = [
            "sentencepiece-py",
        ],
        deps = [
            "fbsource//third-party/pypi/tiktoken:tiktoken",
            "fbsource//third-party/pypi/tokenizers:tokenizers",
        ],
    )
