
Commit bba6759

Add llama2.c tokenizers
Differential Revision: D69579081
Pull Request resolved: #19
1 parent 03744ce · commit bba6759


10 files changed: +656, -1 lines changed


README.md

Lines changed: 7 additions & 1 deletion
@@ -6,7 +6,13 @@ C++ implementations for various tokenizers (sentencepiece, tiktoken etc). Useful
 Depend on https://github.com/google/sentencepiece from Google.

 ## Tiktoken tokenizer
-Adopted from https://github.com/sewenew/tokenizer.
+Adapted from https://github.com/sewenew/tokenizer.
+
+## Huggingface tokenizer
+Compatible with https://github.com/huggingface/tokenizers/.
+
+## Llama2.c tokenizer
+Adapted from https://github.com/karpathy/llama2.c.

 ## License

include/llama2c_tokenizer.h

Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
// @lint-ignore-every CLANGTIDY facebook-hte-RelativeInclude
#pragma once
#include <memory>
#include "tokenizer.h"

namespace tokenizers {

struct TokenIndex {
  const char* str;
  int32_t id;
};

// A simple Byte Pair Encoding (BPE) tokenizer. Note that the original
// tokenizer model won't work with this class as-is; it needs to go through
// tokenizer.py first.
class Llama2cTokenizer : public Tokenizer {
 public:
  explicit Llama2cTokenizer();
  ~Llama2cTokenizer() override;

  Error load(const std::string& tokenizer_path) override;

  Result<std::vector<uint64_t>>
  encode(const std::string& input, int8_t bos, int8_t eos) const override;

  Result<std::string> decode(uint64_t prev_token, uint64_t token)
      const override;

 private:
  std::unique_ptr<char*[]> vocab_ = nullptr;
  std::unique_ptr<float[]> vocab_scores_ = nullptr;
  std::unique_ptr<TokenIndex[]> sorted_vocab_ = nullptr;
  unsigned int max_token_length_ = 0;
  unsigned char byte_pieces_[512]; // stores all single-byte strings
};

} // namespace tokenizers
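
For orientation, here is a minimal usage sketch of the class declared above. It is not part of this commit; it assumes a tokenizer.bin already converted with tokenizer.py, and it assumes the Result accessors (ok()/get()) behave as in the rest of this library.

// Hypothetical usage sketch -- not part of this diff.
// Assumes Result<T> exposes ok() and get(), and Error::Ok as the success value.
#include <cstdint>
#include <cstdio>
#include "llama2c_tokenizer.h"

int main() {
  tokenizers::Llama2cTokenizer tok;
  if (tok.load("tokenizer.bin") != tokenizers::Error::Ok) {
    fprintf(stderr, "failed to load tokenizer\n");
    return 1;
  }
  // Prepend one BOS token, append no EOS token.
  auto ids = tok.encode("Hello world", /*bos=*/1, /*eos=*/0);
  if (!ids.ok()) {
    return 1;
  }
  uint64_t prev = 0;
  for (uint64_t id : ids.get()) {
    auto piece = tok.decode(prev, id);
    if (piece.ok()) {
      printf("%llu -> %s\n", (unsigned long long)id, piece.get().c_str());
    }
    prev = id;
  }
  return 0;
}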

src/llama2c_tokenizer.cpp

Lines changed: 315 additions & 0 deletions
@@ -0,0 +1,315 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
// @lint-ignore-every CLANGTIDY facebook-hte-RelativeInclude
#include "llama2c_tokenizer.h"
#include <cstring>

namespace tokenizers {

static int compare_tokens(const void* a, const void* b) {
  if (((TokenIndex*)a)->str == nullptr) {
    return -1;
  }
  if (((TokenIndex*)b)->str == nullptr) {
    return 1;
  }
  return strcmp(((TokenIndex*)a)->str, ((TokenIndex*)b)->str);
}

Llama2cTokenizer::Llama2cTokenizer() : Tokenizer() {
  for (int i = 0; i < 256; i++) {
    byte_pieces_[i * 2] = (unsigned char)i;
    byte_pieces_[i * 2 + 1] = '\0';
  }
}

/**
 * @brief Load the tokenizer from a file. The tokenizer file contains the
 * vocabulary and scores. The format is: four int32 metadata values
 * (vocab_size, bos, eos, max_token_length), followed by, for each token,
 * its float score, an int32 byte length and the token bytes. Here we read
 * all the vocabulary into memory and keep it sorted for fast lookup.
 *
 * @param tokenizer_path The path to the tokenizer file.
 * @return Error
 */
Error Llama2cTokenizer::load(const std::string& tokenizer_path) {
  if (initialized_) {
    TK_LOG(Info, "Tokenizer already initialized");
    return Error::Ok;
  }
  // read in the file
  FILE* file = fopen(tokenizer_path.c_str(), "rb");
  if (!file) {
    TK_LOG(Error, "couldn't load %s", tokenizer_path.c_str());
    return Error::LoadFailure;
  }
  int32_t metadata[4];
  for (int i = 0; i < 4; i++) {
    if (fread(metadata + i, sizeof(int32_t), 1, file) != 1) {
      TK_LOG(
          Error,
          "Failed to read the metadata at position %d, the tokenizer file is not valid!",
          i);
      return Error::ParseFailure;
    }
  }

  // now we have two vocab_sizes: one from the model and another from the
  // tokenizer file.
  int32_t tokenizer_vocab_size = metadata[0];
  vocab_size_ = tokenizer_vocab_size;
  bos_tok_ = metadata[1];
  eos_tok_ = metadata[2];
  max_token_length_ = metadata[3];

  // allocate space for the vocabulary
  vocab_ = std::make_unique<char*[]>(vocab_size_);
  vocab_scores_ = std::make_unique<float[]>(vocab_size_);
  sorted_vocab_ = std::make_unique<TokenIndex[]>(vocab_size_);

  // read in the vocabulary
  for (int i = 0; i < vocab_size_; i++) {
    if (fread(vocab_scores_.get() + i, sizeof(float), 1, file) != 1) {
      // This is allowed, we just pad the rest of the vocab with <pad> strings
      std::string padding = "<pad>";
      vocab_[i] = new char[padding.length() + 1];
      strcpy(vocab_[i], padding.c_str());
      vocab_[i][padding.length()] = '\0';
      continue;
    }
    int32_t len;
    if (fread(&len, sizeof(int32_t), 1, file) != 1) {
      TK_LOG(Error, "Failed to read the length of the word at index %d", i);
      return Error::ParseFailure;
    }
    vocab_[i] = new char[len + 1];
    if (fread(vocab_[i], len, 1, file) != 1) {
      TK_LOG(
          Error,
          "Failed to read the word, total length %d, index %d\n",
          len,
          i);
      return Error::ParseFailure;
    }
    vocab_[i][len] = '\0'; // add the string terminating token
  }
  fclose(file);

  for (int32_t i = 0; i < vocab_size_; i++) {
    sorted_vocab_[i].str = vocab_[i];
    sorted_vocab_[i].id = i;
  }
  qsort(sorted_vocab_.get(), vocab_size_, sizeof(TokenIndex), compare_tokens);

  initialized_ = true;
  return Error::Ok;
}

Llama2cTokenizer::~Llama2cTokenizer() {
  for (int i = 0; i < vocab_size_; i++) {
    delete[] vocab_[i];
  }
}

/**
 * @brief Decode a token into a string.
 *
 * @param prev_token The previous token.
 * @param token The current token.
 * @return Result<std::string> The string representation of the token.
 */
Result<std::string> Llama2cTokenizer::decode(
    uint64_t prev_token,
    uint64_t token) const {
  TK_CHECK_OK_OR_RETURN_ERROR(Tokenizer::decode_verify(token));
  const char* piece = vocab_[token];
  // following BOS token, sentencepiece decoder strips any leading
  // whitespace
  if (prev_token == bos_tok_ && piece[0] == ' ') {
    piece++;
  }
  // careful, some tokens designate raw bytes, and look like e.g. '<0x01>'
  // parse this and convert and return the actual byte
  unsigned char byte_val;
  if (sscanf(piece, "<0x%02hhX>", &byte_val) == 1) {
    piece = (char*)byte_pieces_ + byte_val * 2;
  }
  std::string res(piece);
  return res;
}

static int32_t
str_lookup(const char* str, TokenIndex* sorted_vocab, int32_t vocab_size) {
  // efficiently find the perfect match for str in vocab, return its index or
  // -1 if not found
  TokenIndex tok = {.str = str}; // acts as the key to search for
  TokenIndex* res = (TokenIndex*)bsearch(
      &tok, sorted_vocab, vocab_size, sizeof(TokenIndex), compare_tokens);
  return res != nullptr ? res->id : -1;
}

/**
 * @brief Encode a string into a sequence of tokens.
 *
 * @param text The string to be encoded.
 * @param bos The number of BOS tokens to prepend to the token list.
 * @param eos The number of EOS tokens to append to the token list.
 * @return Result<std::vector<uint64_t>> The encoded tokens.
 */
Result<std::vector<uint64_t>> Llama2cTokenizer::encode(
    const std::string& text,
    int8_t bos,
    int8_t eos) const {
  if (!initialized_) {
    TK_LOG(Error, "Tokenizer not initialized");
    return Error::Uninitialized;
  }
  // encode the string text (input) into a sequence of tokens; bos gives the
  // number of BOS tokens to prepend and eos the number of EOS tokens to append
  if (text.empty()) {
    TK_LOG(Error, "cannot encode empty text");
    return Error::EncodeFailure;
  }

  // create a temporary buffer that will store merge candidates of always two
  // consecutive tokens: *2 for concat, +1 for null terminator, +2 for UTF8
  // (in case max_token_length is 1)
  char* str_buffer = new char[max_token_length_ * 2 + 1 + 2];
  size_t str_len = 0;

  // start at 0 tokens
  std::vector<uint64_t> tokens;

  // add optional BOS token, if desired
  if (bos >= 0) {
    while (bos--) {
      tokens.push_back(bos_tok_);
    }
  } else {
    TK_LOG(Error, "bos %d should be >= 0", bos);
    return Error::EncodeFailure;
  }

  // add_dummy_prefix is true by default
  // so prepend a dummy prefix token to the input string, but only if text != ""
  // TODO: pretty sure this isn't correct in the general case but I don't have
  // the energy to read more of the sentencepiece code to figure out what it's
  // doing
  const char* space = " ";
  if (text[0] != '\0') {
    int dummy_prefix = str_lookup(space, sorted_vocab_.get(), vocab_size_);
    tokens.push_back(dummy_prefix);
  }

  // Okay UTF-8 time. This will get messy. Here is the reference from Wikipedia:
  // Code point ↔ UTF-8 conversion
  // First code point  Last code point  Byte 1    Byte 2    Byte 3    Byte 4
  // U+0000            U+007F           0xxxxxxx
  // U+0080            U+07FF           110xxxxx  10xxxxxx
  // U+0800            U+FFFF           1110xxxx  10xxxxxx  10xxxxxx
  // U+10000           U+10FFFF         11110xxx  10xxxxxx  10xxxxxx  10xxxxxx

  // process the raw (UTF-8) byte sequence of the input string
  for (const char* c = text.c_str(); *c != '\0'; c++) {
    // reset buffer if the current byte is ASCII or a leading byte
    // 0xC0 is 11000000, so (*c & 0xC0) keeps the first 2 bits and zeros the
    // rest. 0x80 is 10000000; in UTF-8, all continuation bytes start with "10"
    // in the first two bits, so in English this is: "if this byte is not a
    // continuation byte"
    if ((*c & 0xC0) != 0x80) {
      // this byte must be either a leading byte (11...) or an ASCII char
      // (0x...)
      // => reset our location, as we're starting a new UTF-8 codepoint
      str_len = 0;
    }

    // append the current byte to the buffer
    str_buffer[str_len++] =
        *c; // ++ is post-increment, incremented after this line
    str_buffer[str_len] = '\0';

    // while the next character is a continuation byte, continue appending
    // but if there are too many of them, just stop to avoid overrunning
    // str_buffer size.
    if ((*(c + 1) & 0xC0) == 0x80 && str_len < 4) {
      continue;
    }

    // ok c+1 is not a continuation byte, so we've read in a full codepoint
    int id = str_lookup(str_buffer, sorted_vocab_.get(), vocab_size_);
    if (id != -1) {
      // we found this codepoint in vocab, add it as a token
      tokens.push_back(id);
    } else {
      // byte_fallback encoding: just encode each byte as a token
      // +3 is here because the first 3 vocab elements are <unk>, <s>, </s>
      // so the individual bytes only start at index 3
      for (int i = 0; i < str_len; i++) {
        tokens.push_back((unsigned char)str_buffer[i] + 3);
      }
    }
    str_len = 0; // protect against a sequence of stray UTF8 continuation bytes
  }

  // merge the best consecutive pair each iteration, according to the scores in
  // vocab_scores
  while (1) {
    float best_score = -1e10;
    int best_id = -1;
    int best_idx = -1;

    for (int i = 0; i < tokens.size() - 1; i++) {
      // check if we can merge the pair (tokens[i], tokens[i+1])
      snprintf(
          str_buffer,
          max_token_length_ * 2 + 3,
          "%s%s",
          vocab_[tokens[i]],
          vocab_[tokens[i + 1]]);
      int id = str_lookup(str_buffer, sorted_vocab_.get(), vocab_size_);
      if (id != -1 && vocab_scores_[id] > best_score) {
        // this merge pair exists in vocab! record its score and position
        best_score = vocab_scores_[id];
        best_id = id;
        best_idx = i;
      }
    }

    if (best_idx == -1) {
      break; // we couldn't find any more pairs to merge, so we're done
    }

    // merge the consecutive pair (best_idx, best_idx+1) into new token best_id
    tokens[best_idx] = best_id;
    // delete token at position best_idx+1, shift the entire sequence back 1
    for (int i = best_idx + 1; i < tokens.size() - 1; i++) {
      tokens[i] = tokens[i + 1];
    }
    tokens.pop_back(); // token length decreased
  }

  // add optional EOS (=2) token, if desired
  if (eos >= 0) {
    while (eos--) {
      tokens.push_back(eos_tok_);
    }
  } else {
    TK_LOG(Error, "eos %d should be >= 0", eos);
    return Error::EncodeFailure;
  }

  delete[] str_buffer;
  return Result(tokens);
}

} // namespace tokenizers
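
As a companion to load() above, here is a hedged sketch (not part of this commit) that writes a toy file in the layout the loader parses: four int32 metadata fields (vocab size, BOS id, EOS id, max token length), then for each token a float score, an int32 byte length, and the raw token bytes with no terminator. The vocabulary and file name below are invented purely for illustration; real files come from converting a model with tokenizer.py.

// Hypothetical writer for the tokenizer file layout parsed by
// Llama2cTokenizer::load(). Not part of this diff; values are made up.
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <vector>

int main() {
  const std::vector<const char*> words = {"<unk>", "<s>", "</s>", "he", "llo"};
  const std::vector<float> scores = {0.0f, 0.0f, 0.0f, -1.0f, -2.0f};

  int32_t metadata[4];
  metadata[0] = (int32_t)words.size(); // vocab size
  metadata[1] = 1;                     // BOS token id
  metadata[2] = 2;                     // EOS token id
  metadata[3] = 5;                     // max token length in bytes

  FILE* f = fopen("toy_tokenizer.bin", "wb");
  if (!f) {
    return 1;
  }
  fwrite(metadata, sizeof(int32_t), 4, f);
  for (size_t i = 0; i < words.size(); i++) {
    int32_t len = (int32_t)strlen(words[i]);
    fwrite(&scores[i], sizeof(float), 1, f); // token score
    fwrite(&len, sizeof(int32_t), 1, f);     // byte length of the token
    fwrite(words[i], 1, (size_t)len, f);     // token bytes, no '\0'
  }
  fclose(f);
  return 0;
}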

targets.bzl

Lines changed: 13 additions & 0 deletions
@@ -94,3 +94,16 @@ def define_common_targets():
             "nlohmann_json",
         ],
     )
+
+    runtime.cxx_library(
+        name = "llama2c_tokenizer",
+        srcs = [
+            "src/llama2c_tokenizer.cpp",
+        ],
+        exported_deps = [
+            ":headers",
+        ],
+        visibility = [
+            "@EXECUTORCH_CLIENTS",
+        ],
+    )
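
Usage note (not part of this commit): a dependent Buck target can list the new :llama2c_tokenizer library in its deps and include llama2c_tokenizer.h; the exact target label depends on the cell and package from which this targets.bzl is loaded.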
Binary file (16 Bytes) not shown.
