Skip to content

Commit 0e01e7c

Browse files
larryliu0820 and facebook-github-bot
authored and committed
Migrate helios' usage of extension/llm/tokenizer to pytorch/tokenizers
Summary: As titled. Differential Revision: D69885635
1 parent 0763945 commit 0e01e7c

File tree

6 files changed

+190
-0
lines changed

6 files changed

+190
-0
lines changed
File renamed without changes.
File renamed without changes.
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
8+
# Script to rewrite tokenizer model given by sentencepiece to llama2.c format, with lightweight
9+
# postprocessing logic. The output can be consumed by llama2c_tokenizer.cpp.
10+
11+
import argparse
12+
import logging
13+
import os
14+
import struct
15+
from typing import List
16+
17+
from sentencepiece import SentencePieceProcessor as SentencePieceProcessor
18+
19+
20+
class Tokenizer:
    """Thin wrapper around a sentencepiece model that can export it to the
    llama2.c binary tokenizer format (consumed by llama2c_tokenizer.cpp)."""

    def __init__(self, model_path: str):
        """Load the sentencepiece model at ``model_path``.

        :param model_path: path to an existing sentencepiece ``tokenizer.model``.
        """
        assert os.path.isfile(
            model_path
        ), f"Need a valid tokenizer model path but got {model_path}"
        # pyre-fixme[28]: Unexpected keyword argument `model_file` to call `SentencePieceProcessor.__init__`.
        self.sp_model = SentencePieceProcessor(model_file=model_path)
        self.model_path = model_path

        # BOS / EOS token IDs
        self.n_words: int = self.sp_model.vocab_size()
        self.bos_id: int = self.sp_model.bos_id()
        self.eos_id: int = self.sp_model.eos_id()
        logging.info(
            f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
        )
        # Sanity check: vocab size and piece count must agree for a well-formed model.
        # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `get_piece_size`.
        assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()

    def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
        """Encode a string into token ids.

        :param s: input text.
        :param bos: if True, prepend the BOS token id.
        :param eos: if True, append the EOS token id.
        :return: list of token ids.
        """
        assert type(s) is str
        # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `encode`.
        t = self.sp_model.encode(s)
        if bos:
            t = [self.bos_id] + t
        if eos:
            t = t + [self.eos_id]
        return t

    def decode(self, t: List[int]) -> str:
        """Decode a list of token ids back into a string."""
        # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `decode`.
        return self.sp_model.decode(t)

    def decode_token(self, t: int) -> str:
        """Decode a single token id into its string piece."""
        # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `decode`.
        return self.sp_model.decode(t)

    def export(self, output_path: str, *, prepend_padding: bool = False) -> None:
        """
        Export tokenizer.model to another serialization format. Here we did some lightweight
        processing such as supporting prepend padding token, prepend max token length and
        replace '_' back to empty space.

        The binary format is:
        1. vocab size: int32
        2. bos token id: int32
        3. eos token id: int32
        4. max token length: int32
        5. score: float32, len of bytes: int32, token bytes: [byte] for each token

        :param output_path: output path of the new binary.
        :param prepend_padding: a boolean to control if we want to prepend a padding token.

        :return: None
        """

        # get all the tokens (postprocessed) and their scores as floats
        tokens, scores = [], []

        if prepend_padding:
            # Here we use the default padding token and its score.
            tokens.append("<pad>".encode("utf-8"))
            scores.append(-1)

        for i in range(self.n_words):
            # decode the token and light postprocessing
            # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `id_to_piece`.
            t = self.sp_model.id_to_piece(i)
            # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `get_score`.
            s = self.sp_model.get_score(i)
            # sentencepiece use '<s>' as BOS and '</s>' for EOS
            if i == self.bos_id:
                t = "<s>"
            elif i == self.eos_id:
                t = "</s>"
            t = t.replace("▁", " ")  # sentencepiece uses this character as whitespace
            b = t.encode("utf-8")  # bytes of this token, utf-8 encoded

            tokens.append(b)
            scores.append(s)

        # record the max token length
        max_token_length = 0 if not tokens else max(len(t) for t in tokens)

        # write to a binary file
        with open(output_path, "wb") as f:
            # write the vocab size, bos/eos ids and max token length
            # NOTE(review): when prepend_padding is set, len(tokens) is
            # n_words + 1 but the header still records n_words — confirm the
            # consumer accounts for the extra padding entry.
            f.write(
                struct.pack(
                    "IIII", self.n_words, self.bos_id, self.eos_id, max_token_length
                )
            )
            # renamed from `bytes` to avoid shadowing the builtin
            for token_bytes, score in zip(tokens, scores):
                f.write(struct.pack("fI", score, len(token_bytes)))
                f.write(token_bytes)
        logging.info(f"Wrote tokenizer to {output_path}")
116+
117+
118+
if __name__ == "__main__":
    # CLI entry point: convert a sentencepiece tokenizer model into the
    # llama2.c binary format produced by Tokenizer.export.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-t",
        "--tokenizer-model",
        type=str,
        default="tokenizer.model",
        help="path to tokenizer model, given by sentencepiece",
    )
    parser.add_argument(
        "-o",
        "--output-path",
        type=str,
        default=None,
        help="output path of postprocessed tokenizer model",
    )
    parser.add_argument(
        "-p",
        "--prepend-padding",
        action="store_true",
        help="whether to prepend a padding token to the beginning of the tokenizer",
    )

    args = parser.parse_args()

    t = Tokenizer(args.tokenizer_model)

    # Derive the default output path by swapping the extension for ".bin".
    # Using splitext (rather than str.replace(".model", ".bin")) guarantees a
    # distinct output file even when the input name does not end in ".model",
    # which previously would have overwritten the input model file.
    output_path = (
        args.output_path
        if args.output_path
        else os.path.splitext(args.tokenizer_model)[0] + ".bin"
    )
    t.export(output_path, prepend_padding=args.prepend_padding)
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
2+
3+
def define_common_targets():
    """Defines targets that should be shared between fbcode and xplat.

    The directory containing this targets.bzl file should also contain both
    TARGETS and BUCK files that call this function.
    """
    # Library wrapping the sentencepiece -> llama2.c tokenizer conversion code.
    runtime.python_library(
        name = "convert_lib",
        srcs = [
            "__init__.py",
            "convert.py",
        ],
        base_module = "pytorch_tokenizers.tools.llama2c",
        visibility = [
            "//executorch/examples/...",
            "//executorch/extension/llm/export/...",
            "//bento/...",
            "//bento_kernels/...",
            "@EXECUTORCH_CLIENTS",
        ],
        _is_external_target = True,
        # sentencepiece is a third-party dependency; pulled in as an external dep.
        external_deps = [
            "sentencepiece-py",
        ],
    )

    # Runnable CLI entry point for the conversion script above.
    runtime.python_binary(
        name = "convert",
        main_module = "pytorch_tokenizers.tools.llama2c.convert",
        visibility = [
            "//executorch/examples/...",
            "fbsource//xplat/executorch/examples/...",
        ],
        _is_external_target = True,
        deps = [
            ":convert_lib",
        ],
    )

0 commit comments

Comments
 (0)