/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
// @lint-ignore-every CLANGTIDY facebook-hte-RelativeInclude
#include "llama2c_tokenizer.h"
#include <cstdio>
#include <cstdlib>
#include <cstring>

namespace tokenizers {

// qsort/bsearch comparator: orders TokenIndex entries by their string,
// treating a null string as smaller than any non-null string.
static int compare_tokens(const void* a, const void* b) {
  if (((TokenIndex*)a)->str == nullptr) {
    return -1;
  }
  if (((TokenIndex*)b)->str == nullptr) {
    return 1;
  }
  return strcmp(((TokenIndex*)a)->str, ((TokenIndex*)b)->str);
}

Llama2cTokenizer::Llama2cTokenizer() : Tokenizer() {
  // Pre-build a two-character C string for every possible byte value so that
  // byte-fallback tokens can be decoded without allocating: entry i lives at
  // byte_pieces_[i * 2], followed by its NUL terminator.
  for (int i = 0; i < 256; i++) {
    byte_pieces_[i * 2] = (unsigned char)i;
    byte_pieces_[i * 2 + 1] = '\0';
  }
}

/**
 * @brief Load the tokenizer from a file. The file holds the vocabulary and
 * its scores. The header is four 32-bit integers: the vocab size, the BOS
 * token id, the EOS token id, and the maximum token length. It is followed
 * by one (score, word_len, word) record per vocabulary entry. The whole
 * vocabulary is read into memory and kept sorted for fast lookup.
 *
 * @param tokenizer_path The path to the tokenizer file.
 * @return Error
 */
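// On-disk layout implied by the reads below (values are consumed with the
// host's native int32_t/float representation; the field names are just
// labels for this sketch, and strings are not NUL-terminated on disk):
//
//   int32_t vocab_size;
//   int32_t bos_token_id;
//   int32_t eos_token_id;
//   int32_t max_token_length;
//   repeated vocab_size times:
//     float   score;
//     int32_t word_len;
//     char    word[word_len];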
Error Llama2cTokenizer::load(const std::string& tokenizer_path) {
  if (initialized_) {
    TK_LOG(Info, "Tokenizer already initialized");
    return Error::Ok;
  }
  // read in the file
  FILE* file = fopen(tokenizer_path.c_str(), "rb");
  if (!file) {
    TK_LOG(Error, "couldn't load %s", tokenizer_path.c_str());
    return Error::LoadFailure;
  }
  int32_t metadata[4];
  for (int i = 0; i < 4; i++) {
    if (fread(metadata + i, sizeof(int32_t), 1, file) != 1) {
      TK_LOG(
          Error,
          "Failed to read the metadata at position %d, the tokenizer file is not valid!",
          i);
      fclose(file);
      return Error::ParseFailure;
    }
  }

  // There are two vocab sizes, one from the model and one from the tokenizer
  // file; here we use the one read from the tokenizer file.
  int32_t tokenizer_vocab_size = metadata[0];
  vocab_size_ = tokenizer_vocab_size;
  bos_tok_ = metadata[1];
  eos_tok_ = metadata[2];
  max_token_length_ = metadata[3];

  // allocate space for the vocabulary
  vocab_ = std::make_unique<char*[]>(vocab_size_);
  vocab_scores_ = std::make_unique<float[]>(vocab_size_);
  sorted_vocab_ = std::make_unique<TokenIndex[]>(vocab_size_);

  // read in the vocabulary
  for (int i = 0; i < vocab_size_; i++) {
    if (fread(vocab_scores_.get() + i, sizeof(float), 1, file) != 1) {
      // This is allowed; we just pad the rest of the vocab with "<pad>"
      // strings (their scores stay at the value-initialized 0.0f).
      std::string padding = "<pad>";
      vocab_[i] = new char[padding.length() + 1];
      strcpy(vocab_[i], padding.c_str());
      continue;
    }
    int32_t len;
    if (fread(&len, sizeof(int32_t), 1, file) != 1) {
      TK_LOG(Error, "Failed to read the length of the word at index %d", i);
      fclose(file);
      return Error::ParseFailure;
    }
    vocab_[i] = new char[len + 1];
    if (fread(vocab_[i], len, 1, file) != 1) {
      TK_LOG(
          Error,
          "Failed to read the word, total length %d, index %d\n",
          len,
          i);
      fclose(file);
      return Error::ParseFailure;
    }
    vocab_[i][len] = '\0'; // add the string terminating token
  }
  fclose(file);

  for (int32_t i = 0; i < vocab_size_; i++) {
    sorted_vocab_[i].str = vocab_[i];
    sorted_vocab_[i].id = i;
  }
  qsort(sorted_vocab_.get(), vocab_size_, sizeof(TokenIndex), compare_tokens);

  initialized_ = true;
  return Error::Ok;
}

Llama2cTokenizer::~Llama2cTokenizer() {
  for (int i = 0; i < vocab_size_; i++) {
    delete[] vocab_[i];
  }
}

/**
 * @brief Decode a token into a string.
 *
 * @param prev_token The previous token.
 * @param token The current token.
 * @return Result<std::string> The string representation of the token.
 */
Result<std::string> Llama2cTokenizer::decode(
    uint64_t prev_token,
    uint64_t token) const {
  TK_CHECK_OK_OR_RETURN_ERROR(Tokenizer::decode_verify(token));
  const char* piece = vocab_[token];
  // following a BOS token, the sentencepiece decoder strips any leading
  // whitespace
  if (prev_token == bos_tok_ && piece[0] == ' ') {
    piece++;
  }
  // careful, some tokens designate raw bytes and look like e.g. '<0x01>';
  // parse this and return the actual byte
  unsigned char byte_val;
  if (sscanf(piece, "<0x%02hhX>", &byte_val) == 1) {
    piece = (char*)byte_pieces_ + byte_val * 2;
  }
  std::string res(piece);
  return res;
}
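
// For example, a vocabulary entry spelled "<0x0A>" is a byte-fallback token:
// decode() maps it to the single byte 0x0A (newline) via byte_pieces_ rather
// than returning the literal text "<0x0A>". (Illustrative value; the byte
// tokens actually present depend on the vocabulary file.)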

static int32_t
str_lookup(const char* str, TokenIndex* sorted_vocab, int32_t vocab_size) {
  // efficiently find an exact match for str in the sorted vocab; return its
  // index, or -1 if not found
  TokenIndex tok = {.str = str}; // acts as the key to search for
  TokenIndex* res = (TokenIndex*)bsearch(
      &tok, sorted_vocab, vocab_size, sizeof(TokenIndex), compare_tokens);
  return res != nullptr ? res->id : -1;
}

/**
 * @brief Encode a string into a sequence of tokens.
 *
 * @param text The string to be encoded.
 * @param bos The number of BOS tokens to prepend to the result.
 * @param eos The number of EOS tokens to append to the result.
 * @return Result<std::vector<uint64_t>> The encoded token ids.
 */
Result<std::vector<uint64_t>> Llama2cTokenizer::encode(
    const std::string& text,
    int8_t bos,
    int8_t eos) const {
  if (!initialized_) {
    TK_LOG(Error, "Tokenizer not initialized");
    return Error::Uninitialized;
  }
  // encode the string text (input) into a sequence of token ids; bos and eos
  // give the number of BOS (=1) and EOS (=2) tokens to prepend and append
  if (text.empty()) {
    TK_LOG(Error, "cannot encode empty text");
    return Error::EncodeFailure;
  }

  // create a temporary buffer that will store merge candidates of always two
  // consecutive tokens: *2 for concat, +1 for the null terminator, +2 for
  // UTF-8 (in case max_token_length is 1)
  char* str_buffer = new char[max_token_length_ * 2 + 1 + 2];
  size_t str_len = 0;

  // start at 0 tokens
  std::vector<uint64_t> tokens;

  // add optional BOS tokens, if desired
  if (bos >= 0) {
    while (bos--) {
      tokens.push_back(bos_tok_);
    }
  } else {
    TK_LOG(Error, "bos %d should be >= 0", bos);
    delete[] str_buffer;
    return Error::EncodeFailure;
  }

  // add_dummy_prefix is true by default, so prepend a dummy prefix token to
  // the input string, but only if text != ""
  // TODO: pretty sure this isn't correct in the general case but I don't have
  // the energy to read more of the sentencepiece code to figure out what it's
  // doing
  const char* space = " ";
  if (text[0] != '\0') {
    int dummy_prefix = str_lookup(space, sorted_vocab_.get(), vocab_size_);
    tokens.push_back(dummy_prefix);
  }

  // Okay, UTF-8 time. This will get messy. Here is the reference from
  // Wikipedia:
  // Code point ↔ UTF-8 conversion
  // First code point  Last code point  Byte 1    Byte 2    Byte 3    Byte 4
  // U+0000            U+007F           0xxxxxxx
  // U+0080            U+07FF           110xxxxx  10xxxxxx
  // U+0800            U+FFFF           1110xxxx  10xxxxxx  10xxxxxx
  // U+10000           U+10FFFF         11110xxx  10xxxxxx  10xxxxxx  10xxxxxx
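
  // For example, the euro sign U+20AC is encoded as the three bytes
  // 0xE2 0x82 0xAC: the leading byte 0xE2 resets the buffer below, the two
  // continuation bytes are appended, and the completed 3-byte sequence is
  // then looked up in the vocabulary as a single codepoint.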

  // process the raw (UTF-8) byte sequence of the input string
  for (const char* c = text.c_str(); *c != '\0'; c++) {
    // reset the buffer if the current byte is ASCII or a leading byte.
    // 0xC0 is 11000000, so (*c & 0xC0) keeps the first 2 bits and zeros the
    // rest. 0x80 is 10000000; in UTF-8 all continuation bytes start with "10"
    // in the first two bits, so in English this is: "if this byte is not a
    // continuation byte"
    if ((*c & 0xC0) != 0x80) {
      // this byte must be either a leading byte (11...) or an ASCII char
      // (0x...)
      // => reset our location, as we're starting a new UTF-8 codepoint
      str_len = 0;
    }

    // append the current byte to the buffer
    str_buffer[str_len++] =
        *c; // ++ is post-increment, incremented after this line
    str_buffer[str_len] = '\0';

    // while the next character is a continuation byte, continue appending,
    // but if there are too many of them, just stop to avoid overrunning the
    // str_buffer size.
    if ((*(c + 1) & 0xC0) == 0x80 && str_len < 4) {
      continue;
    }

    // ok, c+1 is not a continuation byte, so we've read in a full codepoint
    int id = str_lookup(str_buffer, sorted_vocab_.get(), vocab_size_);
    if (id != -1) {
      // we found this codepoint in vocab, add it as a token
      tokens.push_back(id);
    } else {
      // byte_fallback encoding: just encode each byte as a token
      // +3 is here because the first 3 vocab elements are <unk>, <s>, </s>
      // so the individual bytes only start at index 3
      for (size_t i = 0; i < str_len; i++) {
        tokens.push_back((unsigned char)str_buffer[i] + 3);
      }
    }
    str_len = 0; // protect against a sequence of stray UTF8 continuation bytes
  }

  // merge the best consecutive pair each iteration, according to the scores
  // in vocab_scores
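  // As an illustration (with a hypothetical vocabulary): starting from
  // [" ", "h", "e", "l", "l", "o"], if " h" has the best score among the
  // adjacent concatenations it is merged first, giving
  // [" h", "e", "l", "l", "o"], and the process repeats (e.g. [" he", "llo"],
  // then [" hello"]) until no adjacent pair concatenates to an in-vocabulary
  // token.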
  while (1) {
    float best_score = -1e10;
    int best_id = -1;
    int best_idx = -1;

    for (size_t i = 0; i + 1 < tokens.size(); i++) {
      // check if we can merge the pair (tokens[i], tokens[i+1])
      snprintf(
          str_buffer,
          max_token_length_ * 2 + 3,
          "%s%s",
          vocab_[tokens[i]],
          vocab_[tokens[i + 1]]);
      int id = str_lookup(str_buffer, sorted_vocab_.get(), vocab_size_);
      if (id != -1 && vocab_scores_[id] > best_score) {
        // this merge pair exists in vocab! record its score and position
        best_score = vocab_scores_[id];
        best_id = id;
        best_idx = static_cast<int>(i);
      }
    }

    if (best_idx == -1) {
      break; // we couldn't find any more pairs to merge, so we're done
    }

    // merge the consecutive pair (best_idx, best_idx+1) into new token best_id
    tokens[best_idx] = best_id;
    // delete token at position best_idx+1, shift the entire sequence back 1
    for (size_t i = best_idx + 1; i + 1 < tokens.size(); i++) {
      tokens[i] = tokens[i + 1];
    }
    tokens.pop_back(); // token length decreased
  }

  // add optional EOS (=2) tokens, if desired
  if (eos >= 0) {
    while (eos--) {
      tokens.push_back(eos_tok_);
    }
  } else {
    TK_LOG(Error, "eos %d should be >= 0", eos);
    delete[] str_buffer;
    return Error::EncodeFailure;
  }

  delete[] str_buffer;
  return Result(tokens);
}
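
// A minimal usage sketch (hypothetical caller, error handling elided; assumes
// the Result<T> wrapper exposes ok()/get() accessors, which are not defined
// in this file):
//
//   tokenizers::Llama2cTokenizer tokenizer;
//   if (tokenizer.load("/path/to/tokenizer.bin") == tokenizers::Error::Ok) {
//     auto encoded = tokenizer.encode("hello world", /*bos=*/1, /*eos=*/0);
//     if (encoded.ok()) {
//       uint64_t prev = 0;
//       for (uint64_t id : encoded.get()) {
//         printf("%s", tokenizer.decode(prev, id).get().c_str());
//         prev = id;
//       }
//     }
//   }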

} // namespace tokenizers