
Commit fabd5a1

Refactor tokenizer
Summary: Prepare to add tiktoken
Test Plan: Rely on CI jobs
Reviewers:
Subscribers:
Tasks:
Tags:
1 parent 6251715 commit fabd5a1

File tree

5 files changed: +532 −287 lines changed


runner-et/CMakeLists.txt

Lines changed: 2 additions & 5 deletions
@@ -34,8 +34,8 @@ set(_common_include_directories ${TORCHCHAT_ROOT}/${ET_BUILD_DIR}/src)
 cmake_print_variables(_common_include_directories)
 
 target_include_directories(executorch INTERFACE ${_common_include_directories}) # Ideally ExecuTorch installation process would do this
-add_executable(run run.cpp)
-
+add_executable(run run.cpp ${TORCHCHAT_ROOT}/runner/bpe_tokenizer.cpp)
+target_include_directories(run PRIVATE ${TORCHCHAT_ROOT}/runner)
 # Link ET runtime + extensions
 target_link_libraries(
   run PRIVATE
@@ -55,9 +55,6 @@ target_link_libraries(
 )
 target_link_options_shared_lib(optimized_native_cpu_ops_lib)
 target_link_options_shared_lib(xnnpack_backend)
-target_link_options_shared_lib(XNNPACK)
-target_link_options_shared_lib(pthreadpool)
-target_link_options_shared_lib(cpuinfo)
 # Not clear why linking executorch as whole-archive outside android/apple is leading
 # to double registration. Most likely because of linkage issues.
 # Will figure this out later. Until then use this.
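
The new runner/bpe_tokenizer.cpp shown next includes <tokenizer.h>, which is among the five changed files but is not reproduced on this page. Purely as an orientation aid, here is a minimal sketch of the declarations the .cpp appears to rely on, inferred from how it uses them; the include list, field types, and exact class layout are assumptions, not the header actually committed.

// Sketch only (assumption): the real runner/tokenizer.h is part of this commit
// but is not shown on this page; everything below is inferred from
// bpe_tokenizer.cpp's usage.
#pragma once

#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <string>
#include <vector>

// Entry used for the sorted vocabulary and binary search in bpe_tokenizer.cpp.
struct TokenIndex {
  const char* str;
  int32_t id;
};

class Tokenizer {
 public:
  Tokenizer(int32_t vocab_size, uint64_t bos_tok, uint64_t eos_tok)
      : vocab_size_(vocab_size), bos_tok_(bos_tok), eos_tok_(eos_tok) {}
  virtual ~Tokenizer() {}

  virtual void load(const std::string& tokenizer_path) = 0;
  virtual std::vector<uint64_t>
  encode(const std::string& text, int8_t bos, int8_t eos) = 0;
  virtual std::string decode(uint64_t prev_token, uint64_t token) = 0;

 protected:
  bool initialized_ = false;
  int32_t vocab_size_;
  uint64_t bos_tok_;
  uint64_t eos_tok_;
};

class BPETokenizer : public Tokenizer {
 public:
  BPETokenizer(int32_t vocab_size, uint64_t bos_tok, uint64_t eos_tok);
  ~BPETokenizer() override;

  void load(const std::string& tokenizer_path) override;
  std::vector<uint64_t>
  encode(const std::string& text, int8_t bos, int8_t eos) override;
  std::string decode(uint64_t prev_token, uint64_t token) override;

 private:
  std::unique_ptr<char*[]> vocab_;
  std::unique_ptr<float[]> vocab_scores_;
  std::unique_ptr<TokenIndex[]> sorted_vocab_;
  int32_t max_token_length_ = 0;
  unsigned char byte_pieces_[512]; // 256 single-byte strings, each NUL-terminated
};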

runner/bpe_tokenizer.cpp

Lines changed: 318 additions & 0 deletions
@@ -0,0 +1,318 @@
#include <tokenizer.h>

static int compare_tokens(const void* a, const void* b) {
  if (((TokenIndex*)a)->str == nullptr) {
    return -1;
  }
  if (((TokenIndex*)b)->str == nullptr) {
    return 1;
  }
  return strcmp(((TokenIndex*)a)->str, ((TokenIndex*)b)->str);
}

BPETokenizer::BPETokenizer(
    int32_t vocab_size,
    uint64_t bos_tok,
    uint64_t eos_tok)
    : Tokenizer(vocab_size, bos_tok, eos_tok),
      vocab_(std::make_unique<char*[]>(vocab_size)),
      vocab_scores_(std::make_unique<float[]>(vocab_size)),
      sorted_vocab_(std::make_unique<TokenIndex[]>(vocab_size)) {
  for (int i = 0; i < 256; i++) {
    byte_pieces_[i * 2] = (unsigned char)i;
    byte_pieces_[i * 2 + 1] = '\0';
  }
}

/**
 * @brief Load the tokenizer from a file. The tokenizer file contains the
 * vocabulary and scores. The format is: two int32 metadata values (the
 * tokenizer vocab size and the maximum token length), followed by one
 * (score, word_len, word) record per token. Here we read the whole
 * vocabulary into memory and keep it sorted for fast lookup.
 *
 * @param tokenizer_path The path to the tokenizer file.
 * @return void
 */
void BPETokenizer::load(const std::string& tokenizer_path) {
  if (initialized_) {
    fprintf(stderr, "Tokenizer already initialized.");
    return;
  }
  // read in the file
  FILE* file = fopen(tokenizer_path.c_str(), "rb");
  if (!file) {
    fprintf(stderr, "couldn't load %s", tokenizer_path.c_str());
    exit(EXIT_FAILURE);
  }
  int32_t metadata[2];
  for (int i = 0; i < 2; i++) {
    if (fread(metadata + i, sizeof(int32_t), 1, file) != 1) {
      fprintf(
          stderr,
          "Failed to read the metadata at position %d, the tokenizer file is not valid!",
          i);
      exit(EXIT_FAILURE);
    }
  }

  // now we have two vocab sizes, one from the model and another from the
  // tokenizer file.
  int32_t tokenizer_vocab_size = metadata[0];
  if (tokenizer_vocab_size < vocab_size_) {
    fprintf(
        stderr,
        "The tokenizer vocab size %d is smaller than the model vocab size %d, will add padding tokens.",
        tokenizer_vocab_size,
        vocab_size_);
  } else if (tokenizer_vocab_size > vocab_size_) {
    fprintf(
        stderr,
        "The tokenizer vocab size %d is larger than the model vocab size %d.",
        tokenizer_vocab_size,
        vocab_size_);
  }

  max_token_length_ = metadata[1];

  // allocate space for the vocabulary
  vocab_ = std::make_unique<char*[]>(vocab_size_);
  vocab_scores_ = std::make_unique<float[]>(vocab_size_);
  sorted_vocab_ = std::make_unique<TokenIndex[]>(vocab_size_);

  // read in the vocabulary
  for (int i = 0; i < vocab_size_; i++) {
    if (fread(vocab_scores_.get() + i, sizeof(float), 1, file) != 1) {
      // This is allowed, we just pad the rest of the vocab with <pad> strings
      std::string padding = "<pad>";
      vocab_[i] = new char[padding.length() + 1];
      strcpy(vocab_[i], padding.c_str());
      vocab_[i][padding.length()] = '\0';
      continue;
    }
    int32_t len;
    if (fread(&len, sizeof(int32_t), 1, file) != 1) {
      fprintf(stderr, "Failed to read the length of the word at index %d", i);
      exit(EXIT_FAILURE);
    }
    vocab_[i] = new char[len + 1];
    if (fread(vocab_[i], len, 1, file) != 1) {
      fprintf(
          stderr,
          "Failed to read the word, total length %d, index %d\n",
          len,
          i);
      exit(EXIT_FAILURE);
    }
    vocab_[i][len] = '\0'; // add the string terminating token
  }
  fclose(file);

  for (int32_t i = 0; i < vocab_size_; i++) {
    sorted_vocab_[i].str = vocab_[i];
    sorted_vocab_[i].id = i;
  }
  qsort(sorted_vocab_.get(), vocab_size_, sizeof(TokenIndex), compare_tokens);

  initialized_ = true;
}

BPETokenizer::~BPETokenizer() {
  for (int i = 0; i < vocab_size_; i++) {
    delete[] vocab_[i];
  }
}

/**
 * @brief Decode a token into a string.
 *
 * @param prev_token The previous token.
 * @param token The current token.
 * @return std::string The string representation of the token.
 */
std::string BPETokenizer::decode(uint64_t prev_token, uint64_t token) {
  if (!initialized_) {
    fprintf(stderr, "Tokenizer not initialized");
    exit(EXIT_FAILURE);
  }
  const char* piece = vocab_[token];
  // following BOS token, sentencepiece decoder strips any leading
  // whitespace
  if (prev_token == bos_tok_ && piece[0] == ' ') {
    piece++;
  }
  // careful, some tokens designate raw bytes, and look like e.g. '<0x01>'
  // parse this and convert and return the actual byte
  unsigned char byte_val;
  if (sscanf(piece, "<0x%02hhX>", &byte_val) == 1) {
    piece = (char*)byte_pieces_ + byte_val * 2;
  }
  std::string res(piece);
  return res;
}

static int32_t
str_lookup(const char* str, TokenIndex* sorted_vocab, int32_t vocab_size) {
  // efficiently find the perfect match for str in vocab, return its index or
  // -1 if not found
  TokenIndex tok = {.str = str}; // acts as the key to search for
  TokenIndex* res = (TokenIndex*)bsearch(
      &tok, sorted_vocab, vocab_size, sizeof(TokenIndex), compare_tokens);
  return res != nullptr ? res->id : -1;
}

/**
 * @brief Encode a string into a sequence of tokens.
 *
 * @param text The string to be encoded.
 * @param bos The number of BOS tokens to prepend to the token list.
 * @param eos The number of EOS tokens to append to the token list.
 * @return std::vector<uint64_t> The encoded token ids.
 */
std::vector<uint64_t>
BPETokenizer::encode(const std::string& text, int8_t bos, int8_t eos) {
  if (!initialized_) {
    fprintf(stderr, "Tokenizer not initialized");
    exit(EXIT_FAILURE);
  }
  // encode the string text (input) into an upper-bound preallocated tokens[]
  // array. bos != 0 means prepend the BOS token (=1), eos != 0 means append
  // the EOS token (=2)
  if (text.empty()) {
    fprintf(stderr, "cannot encode empty text");
    exit(EXIT_FAILURE);
  }

  // create a temporary buffer that will store merge candidates of always two
  // consecutive tokens: *2 for concat, +1 for null terminator, +2 for UTF8
  // (in case max_token_length is 1)
  char* str_buffer = new char[max_token_length_ * 2 + 1 + 2];
  size_t str_len = 0;

  // start at 0 tokens
  std::vector<uint64_t> tokens;

  // add optional BOS token, if desired
  if (bos >= 0) {
    while (bos--) {
      tokens.push_back(bos_tok_);
    }
  } else {
    fprintf(stderr, "bos %d should be >= 0", bos);
    exit(EXIT_FAILURE);
  }

  // add_dummy_prefix is true by default
  // so prepend a dummy prefix token to the input string, but only if text != ""
  // TODO: pretty sure this isn't correct in the general case but I don't have
  // the energy to read more of the sentencepiece code to figure out what it's
  // doing
  const char* space = " ";
  if (text[0] != '\0') {
    int dummy_prefix = str_lookup(space, sorted_vocab_.get(), vocab_size_);
    tokens.push_back(dummy_prefix);
  }

  // Okay UTF-8 time. This will get messy. Here is the reference from Wikipedia:
  // Code point <-> UTF-8 conversion
  // First code point  Last code point  Byte 1    Byte 2    Byte 3    Byte 4
  // U+0000            U+007F           0xxxxxxx
  // U+0080            U+07FF           110xxxxx  10xxxxxx
  // U+0800            U+FFFF           1110xxxx  10xxxxxx  10xxxxxx
  // U+10000           U+10FFFF         11110xxx  10xxxxxx  10xxxxxx  10xxxxxx

  // process the raw (UTF-8) byte sequence of the input string
  for (const char* c = text.c_str(); *c != '\0'; c++) {
    // reset buffer if the current byte is ASCII or a leading byte
    // 0xC0 is 11000000, so (*c & 0xC0) keeps the first 2 bits and zeros the
    // rest. 0x80 is 10000000; in UTF-8, all continuation bytes start with "10"
    // in the first two bits, so in English this is: "if this byte is not a
    // continuation byte"
    if ((*c & 0xC0) != 0x80) {
      // this byte must be either a leading byte (11...) or an ASCII char
      // (0x...)
      // => reset our location, as we're starting a new UTF-8 codepoint
      str_len = 0;
    }

    // append the current byte to the buffer
    str_buffer[str_len++] =
        *c; // ++ is post-increment, incremented after this line
    str_buffer[str_len] = '\0';

    // while the next character is a continuation byte, continue appending
    // but if there are too many of them, just stop to avoid overrunning
    // str_buffer size.
    if ((*(c + 1) & 0xC0) == 0x80 && str_len < 4) {
      continue;
    }

    // ok c+1 is not a continuation byte, so we've read in a full codepoint
    int id = str_lookup(str_buffer, sorted_vocab_.get(), vocab_size_);
    if (id != -1) {
      // we found this codepoint in vocab, add it as a token
      tokens.push_back(id);
    } else {
      // byte_fallback encoding: just encode each byte as a token
      // +3 is here because the first 3 vocab elements are <unk>, <s>, </s>
      // so the individual bytes only start at index 3
      for (int i = 0; i < str_len; i++) {
        tokens.push_back((unsigned char)str_buffer[i] + 3);
      }
    }
    str_len = 0; // protect against a sequence of stray UTF8 continuation bytes
  }

  // merge the best consecutive pair each iteration, according to the scores in
  // vocab_scores
  while (1) {
    float best_score = -1e10;
    int best_id = -1;
    int best_idx = -1;

    for (int i = 0; i < tokens.size() - 1; i++) {
      // check if we can merge the pair (tokens[i], tokens[i+1])
      snprintf(
          str_buffer,
          max_token_length_ * 2 + 3,
          "%s%s",
          vocab_[tokens[i]],
          vocab_[tokens[i + 1]]);
      int id = str_lookup(str_buffer, sorted_vocab_.get(), vocab_size_);
      if (id != -1 && vocab_scores_[id] > best_score) {
        // this merge pair exists in vocab! record its score and position
        best_score = vocab_scores_[id];
        best_id = id;
        best_idx = i;
      }
    }

    if (best_idx == -1) {
      break; // we couldn't find any more pairs to merge, so we're done
    }

    // merge the consecutive pair (best_idx, best_idx+1) into new token best_id
    tokens[best_idx] = best_id;
    // delete token at position best_idx+1, shift the entire sequence back 1
    for (int i = best_idx + 1; i < tokens.size() - 1; i++) {
      tokens[i] = tokens[i + 1];
    }
    tokens.pop_back(); // token length decreased
  }

  // add optional EOS (=2) token, if desired
  if (eos >= 0) {
    while (eos--) {
      tokens.push_back(eos_tok_);
    }
  } else {
    fprintf(stderr, "eos %d should be >= 0", eos);
    exit(EXIT_FAILURE);
  }

  delete[] str_buffer;
  return tokens;
}
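
Purely for orientation, and not part of this commit: a hypothetical smoke test for the class above. It writes a toy tokenizer file in the layout that load() parses (two int32 metadata fields, then one float score, int32 length, and raw word bytes per token), then round-trips a short string through encode() and decode(). The file path, the toy vocabulary, and the tokenizer.h sketch it builds against are all assumptions.

// Hypothetical smoke test, not part of the commit.
#include <tokenizer.h>

#include <cstdio>
#include <string>
#include <vector>

int main() {
  const char* path = "/tmp/toy_tokenizer.bin"; // hypothetical path
  const int32_t vocab_size = 260;              // <unk>, <s>, </s>, 256 byte tokens, " "

  // Write a toy tokenizer file in the format load() expects (assumption):
  // int32 vocab size, int32 max token length, then per token a float score,
  // an int32 length, and the raw word bytes.
  FILE* f = fopen(path, "wb");
  int32_t header[2] = {vocab_size, 6}; // 6 == strlen("<0xNN>")
  fwrite(header, sizeof(int32_t), 2, f);
  auto write_entry = [&](const std::string& word) {
    float score = 0.0f;
    int32_t len = (int32_t)word.size();
    fwrite(&score, sizeof(float), 1, f);
    fwrite(&len, sizeof(int32_t), 1, f);
    fwrite(word.data(), 1, word.size(), f);
  };
  write_entry("<unk>");
  write_entry("<s>");
  write_entry("</s>");
  for (int b = 0; b < 256; b++) { // byte-fallback tokens live at indices 3..258
    char buf[8];
    snprintf(buf, sizeof(buf), "<0x%02X>", b);
    write_entry(buf);
  }
  write_entry(" "); // token 259, found by the dummy-prefix lookup in encode()
  fclose(f);

  // Round-trip: with this toy vocab every input byte falls back to a byte token.
  BPETokenizer tok(vocab_size, /*bos_tok=*/1, /*eos_tok=*/2);
  tok.load(path);
  std::vector<uint64_t> ids = tok.encode("hi", /*bos=*/1, /*eos=*/0);
  std::string out;
  for (size_t i = 1; i < ids.size(); i++) { // skip BOS when decoding
    out += tok.decode(ids[i - 1], ids[i]);
  }
  printf("%zu tokens -> \"%s\"\n", ids.size(), out.c_str());
  return 0;
}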

0 commit comments
