#include <tokenizer.h>

// Standard headers for the C I/O, string, and memory routines used below.
#include <cstdio>
#include <cstdlib>
#include <cstring>

static int compare_tokens(const void* a, const void* b) {
  // nullptr entries sort before any real string.
  if (((const TokenIndex*)a)->str == nullptr) {
    return -1;
  }
  if (((const TokenIndex*)b)->str == nullptr) {
    return 1;
  }
  return strcmp(((const TokenIndex*)a)->str, ((const TokenIndex*)b)->str);
}

BPETokenizer::BPETokenizer(
    int32_t vocab_size,
    uint64_t bos_tok,
    uint64_t eos_tok)
    : Tokenizer(vocab_size, bos_tok, eos_tok),
      vocab_(std::make_unique<char*[]>(vocab_size)),
      vocab_scores_(std::make_unique<float[]>(vocab_size)),
      sorted_vocab_(std::make_unique<TokenIndex[]>(vocab_size)) {
  // Pre-compute 256 two-character entries (byte value followed by '\0'),
  // which decode() returns for raw-byte tokens such as "<0x0A>".
  for (int i = 0; i < 256; i++) {
    byte_pieces_[i * 2] = (unsigned char)i;
    byte_pieces_[i * 2 + 1] = '\0';
  }
}

/**
 * @brief Load the tokenizer from a file. The tokenizer file contains the
 * vocabulary and scores. The format is: the first integer is the maximum
 * token length, followed by vocab_size entries of (score, word_len, word).
 * Here we read the whole vocabulary into memory and keep a sorted copy for
 * fast lookup.
 *
 * @param tokenizer_path The path to the tokenizer file.
 * @return void
 */
void BPETokenizer::load(const std::string& tokenizer_path) {
  if (initialized_) {
    fprintf(stderr, "Tokenizer already initialized.\n");
    return;
  }
  // read in the file
  FILE* file = fopen(tokenizer_path.c_str(), "rb");
  if (!file) {
    fprintf(stderr, "couldn't load %s\n", tokenizer_path.c_str());
    exit(EXIT_FAILURE);
  }
  if (fread(&max_token_length_, sizeof(int32_t), 1, file) != 1) {
    fprintf(
        stderr,
        "Failed to read the max token length, the tokenizer file is not valid!\n");
    exit(EXIT_FAILURE);
  }
  // allocate space for the vocabulary
  vocab_ = std::make_unique<char*[]>(vocab_size_);
  vocab_scores_ = std::make_unique<float[]>(vocab_size_);
  sorted_vocab_ = std::make_unique<TokenIndex[]>(vocab_size_);

  // read in the vocabulary
  for (int i = 0; i < vocab_size_; i++) {
    if (fread(vocab_scores_.get() + i, sizeof(float), 1, file) != 1) {
      // This is allowed; we just pad the rest of the vocab with <pad> strings.
      std::string padding = "<pad>";
      vocab_[i] = new char[padding.length() + 1];
      strcpy(vocab_[i], padding.c_str());
      vocab_[i][padding.length()] = '\0';
      continue;
    }
    int32_t len;
    if (fread(&len, sizeof(int32_t), 1, file) != 1) {
      fprintf(stderr, "Failed to read the length of the word at index %d\n", i);
      exit(EXIT_FAILURE);
    }
    vocab_[i] = new char[len + 1];
    if (fread(vocab_[i], len, 1, file) != 1) {
      fprintf(
          stderr,
          "Failed to read the word, total length %d, index %d\n",
          len,
          i);
      exit(EXIT_FAILURE);
    }
    vocab_[i][len] = '\0'; // add the string terminating token
  }
  fclose(file);

  for (int32_t i = 0; i < vocab_size_; i++) {
    sorted_vocab_[i].str = vocab_[i];
    sorted_vocab_[i].id = i;
  }
  qsort(sorted_vocab_.get(), vocab_size_, sizeof(TokenIndex), compare_tokens);

  initialized_ = true;
}
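
// For reference, the on-disk layout consumed by load() above (read with the
// host's native endianness; the field names are only illustrative):
//
//   int32_t max_token_length
//   then, repeated up to vocab_size times:
//     float   score            // merge score, stored in vocab_scores_
//     int32_t word_len         // byte length of the token string (no terminator)
//     char    word[word_len]   // raw token bytes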

BPETokenizer::~BPETokenizer() {
  // If load() was never called, the pointers are still value-initialized to
  // nullptr by std::make_unique, so delete[] is a safe no-op.
  for (int i = 0; i < vocab_size_; i++) {
    delete[] vocab_[i];
  }
}

/**
 * @brief Decode a token into a string.
 *
 * @param prev_token The previous token.
 * @param token The current token.
 * @return std::string The string representation of the token.
 */
std::string BPETokenizer::decode(uint64_t prev_token, uint64_t token) {
  if (!initialized_) {
    fprintf(stderr, "Tokenizer not initialized\n");
    exit(EXIT_FAILURE);
  }
  const char* piece = vocab_[token];
  // following BOS token, sentencepiece decoder strips any leading
  // whitespace
  if (prev_token == bos_tok_ && piece[0] == ' ') {
    piece++;
  }
  // careful, some tokens designate raw bytes, and look like e.g. '<0x01>'
  // parse this and convert and return the actual byte
  unsigned char byte_val;
  if (sscanf(piece, "<0x%02hhX>", &byte_val) == 1) {
    piece = (char*)byte_pieces_ + byte_val * 2;
  }
  std::string res(piece);
  return res;
}
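
// Worked example for the raw-byte branch above: if vocab_[token] is the
// string "<0x0A>", sscanf extracts byte_val == 0x0A, piece is redirected into
// byte_pieces_, and decode() returns a one-byte string containing '\n' rather
// than the literal text "<0x0A>".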

static int32_t
str_lookup(const char* str, TokenIndex* sorted_vocab, int32_t vocab_size) {
  // efficiently find the perfect match for str in vocab, return its index or
  // -1 if not found
  TokenIndex tok = {.str = str}; // acts as the key to search for
  TokenIndex* res = (TokenIndex*)bsearch(
      &tok, sorted_vocab, vocab_size, sizeof(TokenIndex), compare_tokens);
  return res != nullptr ? res->id : -1;
}

/**
 * @brief Encode a string into a sequence of tokens.
 *
 * @param text The string to be encoded.
 * @param bos The number of BOS tokens to prepend to the token list.
 * @param eos The number of EOS tokens to append to the token list.
 * @return std::vector<uint64_t> The encoded tokens.
 */
std::vector<uint64_t>
BPETokenizer::encode(const std::string& text, int8_t bos, int8_t eos) {
  if (!initialized_) {
    fprintf(stderr, "Tokenizer not initialized\n");
    exit(EXIT_FAILURE);
  }
  // encode the input string into a vector of token ids; bos and eos give the
  // number of BOS/EOS tokens to prepend and append, respectively
  if (text.empty()) {
    fprintf(stderr, "cannot encode empty text\n");
    exit(EXIT_FAILURE);
  }

  // create a temporary buffer that will store merge candidates of always two
  // consecutive tokens
  // *2 for concat, +1 for null terminator, +2 for UTF-8 (in case
  // max_token_length is 1)
  char* str_buffer = new char[max_token_length_ * 2 + 1 + 2];
  size_t str_len = 0;

  // start at 0 tokens
  std::vector<uint64_t> tokens;

  // add optional BOS token(s), if desired
  if (bos >= 0) {
    while (bos--) {
      tokens.push_back(bos_tok_);
    }
  } else {
    fprintf(stderr, "bos %d should be >= 0\n", bos);
    exit(EXIT_FAILURE);
  }

  // add_dummy_prefix is true by default
  // so prepend a dummy prefix token to the input string, but only if text != ""
  // TODO: pretty sure this isn't correct in the general case but I don't have
  // the energy to read more of the sentencepiece code to figure out what it's
  // doing
  const char* space = " ";
  if (text[0] != '\0') {
    int dummy_prefix = str_lookup(space, sorted_vocab_.get(), vocab_size_);
    tokens.push_back(dummy_prefix);
  }

  // Okay UTF-8 time. This will get messy. Here is the reference from Wikipedia:
  // Code point ↔ UTF-8 conversion
  // First code point  Last code point  Byte 1    Byte 2    Byte 3    Byte 4
  // U+0000            U+007F           0xxxxxxx
  // U+0080            U+07FF           110xxxxx  10xxxxxx
  // U+0800            U+FFFF           1110xxxx  10xxxxxx  10xxxxxx
  // U+10000           U+10FFFF         11110xxx  10xxxxxx  10xxxxxx  10xxxxxx

  // process the raw (UTF-8) byte sequence of the input string
  for (const char* c = text.c_str(); *c != '\0'; c++) {
    // reset buffer if the current byte is ASCII or a leading byte
    // 0xC0 is 11000000, so (*c & 0xC0) keeps the first 2 bits and zeros the
    // rest; 0x80 is 10000000
    // in UTF-8, all continuation bytes start with "10" in the first two bits
    // so in English this is: "if this byte is not a continuation byte"
    if ((*c & 0xC0) != 0x80) {
      // this byte must be either a leading byte (11...) or an ASCII char
      // (0x...)
      // => reset our location, as we're starting a new UTF-8 codepoint
      str_len = 0;
    }

    // append the current byte to the buffer
    str_buffer[str_len++] =
        *c; // ++ is post-increment, incremented after this line
    str_buffer[str_len] = '\0';

    // while the next character is a continuation byte, continue appending
    // but if there are too many of them, just stop to avoid overrunning
    // str_buffer size.
    if ((*(c + 1) & 0xC0) == 0x80 && str_len < 4) {
      continue;
    }

    // ok c+1 is not a continuation byte, so we've read in a full codepoint
    int id = str_lookup(str_buffer, sorted_vocab_.get(), vocab_size_);
    if (id != -1) {
      // we found this codepoint in vocab, add it as a token
      tokens.push_back(id);
    } else {
      // byte_fallback encoding: just encode each byte as a token
      // +3 is here because the first 3 vocab elements are <unk>, <s>, </s>
      // so the individual bytes only start at index 3
      for (int i = 0; i < str_len; i++) {
        tokens.push_back((unsigned char)str_buffer[i] + 3);
      }
    }
    str_len = 0; // protect against a sequence of stray UTF8 continuation bytes
  }
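
  // Worked example for the loop above: the two-byte character "é" (U+00E9)
  // arrives as the bytes 0xC3 0xA9. The leading byte 0xC3 resets str_len,
  // both bytes are buffered, and since the byte after 0xA9 is not a
  // continuation byte the buffered "é" is looked up in the vocab. If it is
  // absent, the byte_fallback path emits tokens 0xC3 + 3 and 0xA9 + 3.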

  // merge the best consecutive pair in each iteration, according to the
  // scores in vocab_scores
  while (1) {
    float best_score = -1e10;
    int best_id = -1;
    int best_idx = -1;

    for (int i = 0; i < (int)tokens.size() - 1; i++) {
      // check if we can merge the pair (tokens[i], tokens[i+1])
      snprintf(
          str_buffer,
          max_token_length_ * 2 + 3,
          "%s%s",
          vocab_[tokens[i]],
          vocab_[tokens[i + 1]]);
      int id = str_lookup(str_buffer, sorted_vocab_.get(), vocab_size_);
      if (id != -1 && vocab_scores_[id] > best_score) {
        // this merge pair exists in vocab! record its score and position
        best_score = vocab_scores_[id];
        best_id = id;
        best_idx = i;
      }
    }

    if (best_idx == -1) {
      break; // we couldn't find any more pairs to merge, so we're done
    }

    // merge the consecutive pair (best_idx, best_idx+1) into new token best_id
    tokens[best_idx] = best_id;
    // delete token at position best_idx+1, shift the entire sequence back 1
    for (int i = best_idx + 1; i < (int)tokens.size() - 1; i++) {
      tokens[i] = tokens[i + 1];
    }
    tokens.pop_back(); // token length decreased
  }

  // add optional EOS token(s), if desired
  if (eos >= 0) {
    while (eos--) {
      tokens.push_back(eos_tok_);
    }
  } else {
    fprintf(stderr, "eos %d should be >= 0\n", eos);
    exit(EXIT_FAILURE);
  }

  delete[] str_buffer;
  return tokens;
}
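
// ---------------------------------------------------------------------------
// Minimal usage sketch (guarded out so it does not affect this translation
// unit). The tokenizer path, vocab size, and token ids below are hypothetical
// placeholders, not values defined by this file.
// ---------------------------------------------------------------------------
#if 0
int main() {
  BPETokenizer tokenizer(/*vocab_size=*/32000, /*bos_tok=*/1, /*eos_tok=*/2);
  tokenizer.load("tokenizer.bin"); // hypothetical tokenizer file

  // Encode without BOS/EOS, then decode the ids back into text.
  std::vector<uint64_t> ids =
      tokenizer.encode("Hello world", /*bos=*/0, /*eos=*/0);
  uint64_t prev = 1; // treat BOS (id 1 above) as the previous token initially
  for (uint64_t id : ids) {
    printf("%s", tokenizer.decode(prev, id).c_str());
    prev = id;
  }
  printf("\n");
  return 0;
}
#endif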