
Commit 7e779bf

larryliu0820 authored and malfet committed
Refactor tokenizer (#408)
* Refactor tokenizer

  Summary: Prepare to add tiktoken

  Test Plan: Rely on CI jobs

* Fix aoti

* Fix more aoti

* Fix tokenizer read format
1 parent 3ff18bd commit 7e779bf

File tree: 7 files changed (+517, -288 lines)


.github/workflows/pull.yml

Lines changed: 2 additions & 0 deletions
@@ -794,6 +794,8 @@ jobs:
           popd
       - name: Run inference
         run: |
+          set -eou pipefail
+
           export MODEL_DIR=${PWD}/checkpoints/stories15M
           export PROMPT="Once upon a time in a land far away"

runner-aoti/CMakeLists.txt

Lines changed: 7 additions & 1 deletion
@@ -1,11 +1,17 @@
 cmake_minimum_required(VERSION 3.18 FATAL_ERROR)
 project(llama2so LANGUAGES CXX)
+IF(DEFINED ENV{TORCHCHAT_ROOT})
+  set(TORCHCHAT_ROOT $ENV{TORCHCHAT_ROOT})
+ELSE()
+  set(TORCHCHAT_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/..)
+ENDIF()
 
 find_package(CUDA)
 
 find_package(Torch REQUIRED)
 set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g ${TORCH_CXX_FLAGS} -fpermissive")
 
-add_executable(run run.cpp)
+add_executable(run run.cpp ${TORCHCHAT_ROOT}/runner/bpe_tokenizer.cpp)
+target_include_directories(run PRIVATE ${TORCHCHAT_ROOT}/runner)
 target_link_libraries(run "${TORCH_LIBRARIES}" m)
 set_property(TARGET run PROPERTY CXX_STANDARD 17)

runner-et/CMakeLists.txt

Lines changed: 2 additions & 5 deletions
@@ -34,8 +34,8 @@ set(_common_include_directories ${TORCHCHAT_ROOT}/${ET_BUILD_DIR}/src)
 cmake_print_variables(_common_include_directories)
 
 target_include_directories(executorch INTERFACE ${_common_include_directories}) # Ideally ExecuTorch installation process would do this
-add_executable(run run.cpp)
-
+add_executable(run run.cpp ${TORCHCHAT_ROOT}/runner/bpe_tokenizer.cpp)
+target_include_directories(run PRIVATE ${TORCHCHAT_ROOT}/runner)
 # Link ET runtime + extensions
 target_link_libraries(
   run PRIVATE
@@ -55,9 +55,6 @@ target_link_libraries(
 )
 target_link_options_shared_lib(optimized_native_cpu_ops_lib)
 target_link_options_shared_lib(xnnpack_backend)
-target_link_options_shared_lib(XNNPACK)
-target_link_options_shared_lib(pthreadpool)
-target_link_options_shared_lib(cpuinfo)
 # Not clear why linking executorch as whole-archive outside android/apple is leading
 # to double registration. Most likely because of linkage issues.
 # Will figure this out later. Until then use this.

runner/bpe_tokenizer.cpp

Lines changed: 294 additions & 0 deletions
@@ -0,0 +1,294 @@
#include <tokenizer.h>

static int compare_tokens(const void* a, const void* b) {
  if (((TokenIndex*)a)->str == nullptr) {
    return -1;
  }
  if (((TokenIndex*)b)->str == nullptr) {
    return 1;
  }
  return strcmp(((TokenIndex*)a)->str, ((TokenIndex*)b)->str);
}

BPETokenizer::BPETokenizer(
    int32_t vocab_size,
    uint64_t bos_tok,
    uint64_t eos_tok)
    : Tokenizer(vocab_size, bos_tok, eos_tok),
      vocab_(std::make_unique<char*[]>(vocab_size)),
      vocab_scores_(std::make_unique<float[]>(vocab_size)),
      sorted_vocab_(std::make_unique<TokenIndex[]>(vocab_size)) {
  for (int i = 0; i < 256; i++) {
    byte_pieces_[i * 2] = (unsigned char)i;
    byte_pieces_[i * 2 + 1] = '\0';
  }
}

/**
 * @brief Load the tokenizer from a file. The tokenizer file contains the
 * vocabulary and scores. The format is: the first integer is the maximum
 * token length, followed by a list of (score, word_len, word) triples. Here
 * we read all the vocabulary into memory and keep it sorted for fast lookup.
 *
 * @param tokenizer_path The path to the tokenizer file.
 * @return void
 */
void BPETokenizer::load(const std::string& tokenizer_path) {
  if (initialized_) {
    fprintf(stderr, "Tokenizer already initialized.\n");
    return;
  }
  // read in the file
  FILE* file = fopen(tokenizer_path.c_str(), "rb");
  if (!file) {
    fprintf(stderr, "couldn't load %s\n", tokenizer_path.c_str());
    exit(EXIT_FAILURE);
  }
  if (fread(&max_token_length_, sizeof(int32_t), 1, file) != 1) {
    fprintf(
        stderr,
        "Failed to read the max token length, the tokenizer file is not valid!\n");
    exit(EXIT_FAILURE);
  }
  // allocate space for the vocabulary
  vocab_ = std::make_unique<char*[]>(vocab_size_);
  vocab_scores_ = std::make_unique<float[]>(vocab_size_);
  sorted_vocab_ = std::make_unique<TokenIndex[]>(vocab_size_);

  // read in the vocabulary
  for (int i = 0; i < vocab_size_; i++) {
    if (fread(vocab_scores_.get() + i, sizeof(float), 1, file) != 1) {
      // This is allowed, we just pad the rest of the vocab with <pad> strings
      std::string padding = "<pad>";
      vocab_[i] = new char[padding.length() + 1];
      strcpy(vocab_[i], padding.c_str());
      vocab_[i][padding.length()] = '\0';
      continue;
    }
    int32_t len;
    if (fread(&len, sizeof(int32_t), 1, file) != 1) {
      fprintf(stderr, "Failed to read the length of the word at index %d\n", i);
      exit(EXIT_FAILURE);
    }
    vocab_[i] = new char[len + 1];
    if (fread(vocab_[i], len, 1, file) != 1) {
      fprintf(
          stderr,
          "Failed to read the word, total length %d, index %d\n",
          len,
          i);
      exit(EXIT_FAILURE);
    }
    vocab_[i][len] = '\0'; // add the string terminating token
  }
  fclose(file);

  for (int32_t i = 0; i < vocab_size_; i++) {
    sorted_vocab_[i].str = vocab_[i];
    sorted_vocab_[i].id = i;
  }
  qsort(sorted_vocab_.get(), vocab_size_, sizeof(TokenIndex), compare_tokens);

  initialized_ = true;
}

BPETokenizer::~BPETokenizer() {
  for (int i = 0; i < vocab_size_; i++) {
    delete[] vocab_[i];
  }
}

/**
 * @brief Decode a token into a string.
 *
 * @param prev_token The previous token.
 * @param token The current token.
 * @return std::string The string representation of the token.
 */
std::string BPETokenizer::decode(uint64_t prev_token, uint64_t token) {
  if (!initialized_) {
    fprintf(stderr, "Tokenizer not initialized\n");
    exit(EXIT_FAILURE);
  }
  const char* piece = vocab_[token];
  // following BOS token, sentencepiece decoder strips any leading
  // whitespace
  if (prev_token == bos_tok_ && piece[0] == ' ') {
    piece++;
  }
  // careful, some tokens designate raw bytes, and look like e.g. '<0x01>'
  // parse this and convert and return the actual byte
  unsigned char byte_val;
  if (sscanf(piece, "<0x%02hhX>", &byte_val) == 1) {
    piece = (char*)byte_pieces_ + byte_val * 2;
  }
  std::string res(piece);
  return res;
}

static int32_t
str_lookup(const char* str, TokenIndex* sorted_vocab, int32_t vocab_size) {
  // efficiently find the perfect match for str in vocab, return its index or
  // -1 if not found
  TokenIndex tok = {.str = str}; // acts as the key to search for
  TokenIndex* res = (TokenIndex*)bsearch(
      &tok, sorted_vocab, vocab_size, sizeof(TokenIndex), compare_tokens);
  return res != nullptr ? res->id : -1;
}

/**
 * @brief Encode a string into a sequence of tokens.
 *
 * @param text The string to be encoded.
 * @param bos The number of BOS tokens to prepend to the token list.
 * @param eos The number of EOS tokens to append to the token list.
 * @return std::vector<uint64_t> The encoded tokens.
 */
std::vector<uint64_t>
BPETokenizer::encode(const std::string& text, int8_t bos, int8_t eos) {
  if (!initialized_) {
    fprintf(stderr, "Tokenizer not initialized\n");
    exit(EXIT_FAILURE);
  }
  // encode the string text (input) into an upper-bound preallocated tokens[]
  // array. bos != 0 means prepend the BOS token (=1), eos != 0 means append
  // the EOS token (=2)
  if (text.empty()) {
    fprintf(stderr, "cannot encode empty text\n");
    exit(EXIT_FAILURE);
  }

  // create a temporary buffer that will store merge candidates of always two
  // consecutive tokens. *2 for concat, +1 for null terminator, +2 for UTF8
  // (in case max_token_length is 1)
  char* str_buffer = new char[max_token_length_ * 2 + 1 + 2];
  size_t str_len = 0;

  // start at 0 tokens
  std::vector<uint64_t> tokens;

  // add optional BOS token, if desired
  if (bos > 0) {
    while (bos--) {
      tokens.push_back(bos_tok_);
    }
  } else {
    fprintf(stderr, "bos %d should be >= 0\n", bos);
    exit(EXIT_FAILURE);
  }

  // add_dummy_prefix is true by default
  // so prepend a dummy prefix token to the input string, but only if text != ""
  // TODO: pretty sure this isn't correct in the general case but I don't have
  // the energy to read more of the sentencepiece code to figure out what it's
  // doing
  const char* space = " ";
  if (text[0] != '\0') {
    int dummy_prefix = str_lookup(space, sorted_vocab_.get(), vocab_size_);
    tokens.push_back(dummy_prefix);
  }

  // Okay UTF-8 time. This will get messy. Here is the reference from Wikipedia:
  // Code point ↔ UTF-8 conversion
  // First code point  Last code point  Byte 1    Byte 2    Byte 3    Byte 4
  // U+0000            U+007F           0xxxxxxx
  // U+0080            U+07FF           110xxxxx  10xxxxxx
  // U+0800            U+FFFF           1110xxxx  10xxxxxx  10xxxxxx
  // U+10000           U+10FFFF         11110xxx  10xxxxxx  10xxxxxx  10xxxxxx

  // process the raw (UTF-8) byte sequence of the input string
  for (const char* c = text.c_str(); *c != '\0'; c++) {
    // reset buffer if the current byte is ASCII or a leading byte
    // 0xC0 is 11000000, so (*c & 0xC0) keeps the first 2 bits and zeros the
    // rest. 0x80 is 10000000; in UTF-8, all continuation bytes start with "10"
    // in the first two bits, so in English this is: "if this byte is not a
    // continuation byte"
    if ((*c & 0xC0) != 0x80) {
      // this byte must be either a leading byte (11...) or an ASCII char
      // (0x...)
      // => reset our location, as we're starting a new UTF-8 codepoint
      str_len = 0;
    }

    // append the current byte to the buffer
    str_buffer[str_len++] = *c; // ++ is post-increment, incremented after this line
    str_buffer[str_len] = '\0';

    // while the next character is a continuation byte, continue appending
    // but if there are too many of them, just stop to avoid overrunning
    // str_buffer size.
    if ((*(c + 1) & 0xC0) == 0x80 && str_len < 4) {
      continue;
    }

    // ok c+1 is not a continuation byte, so we've read in a full codepoint
    int id = str_lookup(str_buffer, sorted_vocab_.get(), vocab_size_);
    if (id != -1) {
      // we found this codepoint in vocab, add it as a token
      tokens.push_back(id);
    } else {
      // byte_fallback encoding: just encode each byte as a token
      // +3 is here because the first 3 vocab elements are <unk>, <s>, </s>
      // so the individual bytes only start at index 3
      for (int i = 0; i < str_len; i++) {
        tokens.push_back((unsigned char)str_buffer[i] + 3);
      }
    }
    str_len = 0; // protect against a sequence of stray UTF8 continuation bytes
  }

  // merge the best consecutive pair each iteration, according to the scores
  // in vocab_scores
  while (1) {
    float best_score = -1e10;
    int best_id = -1;
    int best_idx = -1;

    for (int i = 0; i < tokens.size() - 1; i++) {
      // check if we can merge the pair (tokens[i], tokens[i+1])
      snprintf(
          str_buffer,
          max_token_length_ * 2 + 3,
          "%s%s",
          vocab_[tokens[i]],
          vocab_[tokens[i + 1]]);
      int id = str_lookup(str_buffer, sorted_vocab_.get(), vocab_size_);
      if (id != -1 && vocab_scores_[id] > best_score) {
        // this merge pair exists in vocab! record its score and position
        best_score = vocab_scores_[id];
        best_id = id;
        best_idx = i;
      }
    }

    if (best_idx == -1) {
      break; // we couldn't find any more pairs to merge, so we're done
    }

    // merge the consecutive pair (best_idx, best_idx+1) into new token best_id
    tokens[best_idx] = best_id;
    // delete token at position best_idx+1, shift the entire sequence back 1
    for (int i = best_idx + 1; i < tokens.size() - 1; i++) {
      tokens[i] = tokens[i + 1];
    }
    tokens.pop_back(); // token length decreased
  }

  // add optional EOS (=2) token, if desired
  if (eos >= 0) {
    while (eos--) {
      tokens.push_back(eos_tok_);
    }
  } else {
    fprintf(stderr, "eos %d should be >= 0\n", eos);
    exit(EXIT_FAILURE);
  }

  delete[] str_buffer;
  return tokens;
}
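For reference, load() above expects the llama2.c-style binary layout: one int32 giving the maximum token length, followed by one (float score, int32 word_len, raw word bytes) record per vocabulary entry, with no separators or terminators stored on disk. Below is a minimal writer sketch for that layout; the export_tokenizer_bin helper is hypothetical and only illustrates the format load() reads back, it is not part of this commit.

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <string>
#include <vector>

// Hypothetical helper: writes a file in the layout BPETokenizer::load() reads:
// int32 max_token_length, then per entry a float score, an int32 byte length,
// and the raw word bytes (no null terminator stored on disk).
bool export_tokenizer_bin(
    const std::string& path,
    const std::vector<std::string>& words,
    const std::vector<float>& scores) {
  FILE* out = fopen(path.c_str(), "wb");
  if (!out) {
    return false;
  }
  int32_t max_len = 0;
  for (const auto& w : words) {
    max_len = std::max<int32_t>(max_len, (int32_t)w.size());
  }
  fwrite(&max_len, sizeof(int32_t), 1, out);
  for (size_t i = 0; i < words.size(); i++) {
    float score = scores[i];
    int32_t len = (int32_t)words[i].size();
    fwrite(&score, sizeof(float), 1, out);
    fwrite(&len, sizeof(int32_t), 1, out);
    fwrite(words[i].data(), 1, (size_t)len, out);
  }
  fclose(out);
  return true;
}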

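And a minimal usage sketch of the refactored interface, assuming tokenizer.h declares BPETokenizer with the API used above; the vocab size (32000), BOS/EOS ids (1/2), and the tokenizer.bin path are placeholders for a llama2-style checkpoint, not values taken from this commit.

#include <cstdint>
#include <cstdio>
#include <string>
#include <vector>

#include <tokenizer.h> // assumed to declare BPETokenizer as in runner/

int main() {
  // placeholder vocab size and special-token ids for a llama2-style model
  BPETokenizer tokenizer(/*vocab_size=*/32000, /*bos_tok=*/1, /*eos_tok=*/2);
  tokenizer.load("tokenizer.bin");

  // e.g. prepend one BOS and append no EOS when encoding a prompt
  std::vector<uint64_t> tokens =
      tokenizer.encode("Once upon a time in a land far away", /*bos=*/1, /*eos=*/0);

  // decode (prev, cur) pairs so the leading-space rule after BOS is applied
  uint64_t prev = tokens[0];
  for (size_t i = 1; i < tokens.size(); i++) {
    printf("%s", tokenizer.decode(prev, tokens[i]).c_str());
    prev = tokens[i];
  }
  printf("\n");
  return 0;
}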