
Commit 15144bd

larryliu0820 authored and facebook-github-bot committed
Consolidate tokenizer interface (#2954)
Summary: Change the tokenizer APIs to:

```
Result<std::vector<uint64_t>> encode(const std::string& input, int8_t bos, int8_t eos);
Result<std::string> decode(uint64_t prev_token, uint64_t token);
```

Note that we use `uint64_t` for token ids just to be safe, and the encode() API now returns a std::vector of tokens.

Differential Revision: D55944780
1 parent 564c276 commit 15144bd
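
As a quick orientation, here is a minimal caller-side sketch of the consolidated interface described above. It only uses calls visible in this diff (encode, decode, Result::ok()/get()/error(), Error::Ok); the include path, namespace qualifier, and function name are illustrative assumptions, not part of the commit.

```
// Hypothetical usage sketch of the consolidated tokenizer API (not from this commit).
#include <cstdio>
#include <string>
#include <vector>

#include <executorch/examples/models/llama2/tokenizer/tokenizer.h>  // assumed path

using namespace torch::executor;  // assumed namespace

Error print_prompt_pieces(Tokenizer& tokenizer, const std::string& prompt) {
  // Token ids are uint64_t after this change; prepend one BOS, append no EOS.
  Result<std::vector<uint64_t>> encoded = tokenizer.encode(prompt, /*bos=*/1, /*eos=*/0);
  if (!encoded.ok()) {
    return encoded.error();
  }
  const std::vector<uint64_t>& tokens = encoded.get();

  // decode() now returns an owned std::string instead of a const char*.
  uint64_t prev = tokens[0];
  for (size_t i = 1; i < tokens.size(); ++i) {
    Result<std::string> piece = tokenizer.decode(prev, tokens[i]);
    if (!piece.ok()) {
      return piece.error();
    }
    std::printf("%s", piece.get().c_str());
    prev = tokens[i];
  }
  return Error::Ok;
}
```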

File tree

3 files changed: +44 −53 lines changed


examples/models/llama2/runner/runner.cpp

Lines changed: 15 additions & 18 deletions
@@ -81,14 +81,14 @@ Error Runner::load() {
   if (tokenizer_->bos_tok() != bos_id_) {
     ET_LOG(
         Error,
-        "Tokenizer's BOS id %d does not match model's BOS id %d, will override tokenizer's BOS.",
+        "Tokenizer's BOS id %lu does not match model's BOS id %d, will override tokenizer's BOS.",
         tokenizer_->bos_tok(),
         bos_id_);
   }
   if (tokenizer_->eos_tok() != eos_id_) {
     ET_LOG(
         Error,
-        "Tokenizer's EOS id %d does not match model's EOS id %d, will override tokenizer's EOS.",
+        "Tokenizer's EOS id %lu does not match model's EOS id %d, will override tokenizer's EOS.",
         tokenizer_->eos_tok(),
         eos_id_);
   }
@@ -227,20 +227,18 @@ Error Runner::generate(
   stats_.inference_start_ms = util::time_in_ms();
   shouldStop_ = false;

-  // encode the (string) prompt into tokens sequence
-  int num_prompt_tokens = 0;
-  // max # of prompt tokens: len(prompt) + '\0', ?BOS, ?EOS
-  int* prompt_tokens = new int[prompt.size() + 1 + n_bos_ + n_eos_];
-
   // Set the sequence length to the max seq length if not provided
   seq_len = (seq_len > 0 && seq_len <= max_seq_len_) ? seq_len : max_seq_len_;

-  tokenizer_->encode(
-      prompt.c_str(),
-      n_bos_,
-      append_eos_ ? n_eos_ : 0,
-      prompt_tokens,
-      &num_prompt_tokens);
+  Result<std::vector<uint64_t>> encode_res =
+      tokenizer_->encode(prompt, n_bos_, append_eos_ ? n_eos_ : 0);
+
+  ET_CHECK_OK_OR_RETURN_ERROR(
+      encode_res.error(), "Failed to encode prompt %s", prompt.c_str());
+
+  // encode the (string) prompt into tokens sequence
+  std::vector<uint64_t> prompt_tokens = encode_res.get();
+  int num_prompt_tokens = prompt_tokens.size();

   ET_CHECK_MSG(num_prompt_tokens >= 1, "Expected at least 1 prompt token");
   ET_CHECK_MSG(
@@ -303,13 +301,13 @@ Error Runner::generate(

   // Print the prompt for consistent output between single token prefill and
   // batch prefill.
-  int prev = prompt_tokens[0];
-  int cur;
+  uint64_t prev = prompt_tokens[0];
+  uint64_t cur;
   for (int i = 1; i < num_prompt_tokens; i++) {
     cur = prompt_tokens[i];
     auto piece_res = tokenizer_->decode(prev, cur);
     ET_CHECK_OK_OR_RETURN_ERROR(piece_res.error());
-    util::safe_printf(piece_res.get());
+    util::safe_printf(piece_res.get().c_str());
     fflush(stdout);
     prev = cur;
   }
@@ -361,7 +359,7 @@ Error Runner::generate(
     // print the token as string, decode it with the Tokenizer object
     auto piece_res = tokenizer_->decode(prev_token, cur_token);
     ET_CHECK(piece_res.ok());
-    const char* piece = piece_res.get();
+    const char* piece = piece_res.get().c_str();

     // same as printf("%s", piece), but skips "unsafe" bytes
     util::safe_printf(piece);
@@ -396,7 +394,6 @@ Error Runner::generate(
     stats_callback(stats_);
   }

-  delete[] prompt_tokens;
   return Error::Ok;
 }

examples/models/llama2/tokenizer/tokenizer.cpp

Lines changed: 21 additions & 24 deletions
@@ -23,7 +23,7 @@ static int compare_tokens(const void* a, const void* b) {
   return strcmp(((TokenIndex*)a)->str, ((TokenIndex*)b)->str);
 }

-Tokenizer::Tokenizer(int32_t vocab_size, int32_t bos_tok, int32_t eos_tok)
+Tokenizer::Tokenizer(int32_t vocab_size, uint64_t bos_tok, uint64_t eos_tok)
     : initialized_(false),
       vocab_size_(vocab_size),
       bos_tok_(bos_tok),
@@ -142,10 +142,10 @@ Tokenizer::~Tokenizer() {
  *
  * @param prev_token The previous token.
  * @param token The current token.
- * @return Result<const char*> A pointer to the string representation of the
+ * @return Result<std::string> A pointer to the string representation of the
  * token.
  */
-Result<const char*> Tokenizer::decode(int32_t prev_token, int32_t token) {
+Result<std::string> Tokenizer::decode(uint64_t prev_token, uint64_t token) {
   if (!initialized_) {
     ET_LOG(Error, "Tokenizer not initialized");
     return Error::NotSupported;
@@ -162,7 +162,8 @@ Result<const char*> Tokenizer::decode(int32_t prev_token, int32_t token) {
   if (sscanf(piece, "<0x%02hhX>", &byte_val) == 1) {
     piece = (char*)byte_pieces_ + byte_val * 2;
   }
-  return piece;
+  std::string res(piece);
+  return res;
 }

 static int32_t
@@ -183,23 +184,19 @@ str_lookup(const char* str, TokenIndex* sorted_vocab, int32_t vocab_size) {
  * @param eos The number of EOS to append to the token list.
  * @param tokens The output tokens.
  * @param n_tokens The number of tokens.
- * @return Error
+ * @return Result<std::vector<uint64_t>>
  */
-Error Tokenizer::encode(
-    const char* text,
-    int8_t bos,
-    int8_t eos,
-    int32_t* tokens,
-    int32_t* n_tokens) {
+Result<std::vector<uint64_t>>
+Tokenizer::encode(const std::string& text, int8_t bos, int8_t eos) {
   if (!initialized_) {
     ET_LOG(Error, "Tokenizer not initialized");
     return Error::NotSupported;
   }
   // encode the string text (input) into an upper-bound preallocated tokens[]
   // array bos != 0 means prepend the BOS token (=1), eos != 0 means append the
   // EOS token (=2)
-  if (text == nullptr) {
-    ET_LOG(Error, "cannot encode null text");
+  if (text.empty()) {
+    ET_LOG(Error, "cannot encode empty text");
     return Error::InvalidArgument;
   }

@@ -210,12 +207,12 @@ Error Tokenizer::encode(
   size_t str_len = 0;

   // start at 0 tokens
-  *n_tokens = 0;
+  std::vector<uint64_t> tokens;

   // add optional BOS token, if desired
   if (bos > 0) {
     while (bos--) {
-      tokens[(*n_tokens)++] = bos_tok_;
+      tokens.push_back(bos_tok_);
     }
   } else {
     ET_LOG(Error, "bos %d should be >= 0", bos);
@@ -230,7 +227,7 @@
   const char* space = " ";
   if (text[0] != '\0') {
     int dummy_prefix = str_lookup(space, sorted_vocab_.get(), vocab_size_);
-    tokens[(*n_tokens)++] = dummy_prefix;
+    tokens.push_back(dummy_prefix);
   }

   // Okay UTF-8 time. This will get messy. Here is the reference from Wikipedia:
@@ -242,7 +239,7 @@
   // U+10000  U+10FFFF   11110xxx 10xxxxxx 10xxxxxx 10xxxxxx

   // process the raw (UTF-8) byte sequence of the input string
-  for (const char* c = text; *c != '\0'; c++) {
+  for (const char* c = text.c_str(); *c != '\0'; c++) {
     // reset buffer if the current byte is ASCII or a leading byte
     // 0xC0 is 11000000, so (*c & 0xC0) keeps the first 2 bits and zeros the
     // rest 0x80 is 10000000 in UTF-8, all continuation bytes start with "10" in
@@ -271,13 +268,13 @@
     int id = str_lookup(str_buffer, sorted_vocab_.get(), vocab_size_);
     if (id != -1) {
       // we found this codepoint in vocab, add it as a token
-      tokens[(*n_tokens)++] = id;
+      tokens.push_back(id);
     } else {
       // byte_fallback encoding: just encode each byte as a token
       // +3 is here because the first 3 vocab elements are <unk>, <s>, </s>
       // so the individual bytes only start at index 3
       for (int i = 0; i < str_len; i++) {
-        tokens[(*n_tokens)++] = (unsigned char)str_buffer[i] + 3;
+        tokens.push_back((unsigned char)str_buffer[i] + 3);
       }
     }
     str_len = 0; // protect against a sequence of stray UTF8 continuation bytes
@@ -290,7 +287,7 @@
     int best_id = -1;
     int best_idx = -1;

-    for (int i = 0; i < (*n_tokens - 1); i++) {
+    for (int i = 0; i < tokens.size() - 1; i++) {
       // check if we can merge the pair (tokens[i], tokens[i+1])
       snprintf(
           str_buffer,
@@ -314,24 +311,24 @@
     // merge the consecutive pair (best_idx, best_idx+1) into new token best_id
     tokens[best_idx] = best_id;
     // delete token at position best_idx+1, shift the entire sequence back 1
-    for (int i = best_idx + 1; i < (*n_tokens - 1); i++) {
+    for (int i = best_idx + 1; i < tokens.size() - 1; i++) {
       tokens[i] = tokens[i + 1];
     }
-    (*n_tokens)--; // token length decreased
+    tokens.pop_back(); // token length decreased
   }

   // add optional EOS (=2) token, if desired
   if (eos >= 0) {
     while (eos--) {
-      tokens[(*n_tokens)++] = eos_tok_;
+      tokens.push_back(eos_tok_);
     }
   } else {
     ET_LOG(Error, "eos %d should be >= 0", eos);
     return Error::InvalidArgument;
   }

   delete[] str_buffer;
-  return Error::Ok;
+  return Result(tokens);
 }

 } // namespace executor
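
One detail worth calling out from the encode() hunks above: in the byte-fallback path, each raw byte maps to token id byte + 3 because ids 0, 1, and 2 are reserved for <unk>, <s>, and </s>. A tiny illustrative helper (hypothetical, not part of this commit) that mirrors that mapping:

```
// Illustrative only: the byte-fallback mapping used inside Tokenizer::encode().
// Assumes the vocab layout <unk>=0, <s>=1, </s>=2 described in the code comments.
inline uint64_t byte_fallback_token(unsigned char byte) {
  return static_cast<uint64_t>(byte) + 3;  // e.g. byte 0x41 ('A') -> token 68
}
```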

examples/models/llama2/tokenizer/tokenizer.h

Lines changed: 8 additions & 11 deletions
@@ -16,6 +16,7 @@
 #include <cstring>
 #include <memory>
 #include <string>
+#include <vector>

 #include <executorch/runtime/core/error.h>
 #include <executorch/runtime/core/exec_aten/exec_aten.h>
@@ -32,37 +33,33 @@ struct TokenIndex {

 class Tokenizer {
  public:
-  explicit Tokenizer(int32_t vocab_size, int32_t bos_tok, int32_t eos_tok);
+  explicit Tokenizer(int32_t vocab_size, uint64_t bos_tok, uint64_t eos_tok);
   ~Tokenizer();

   Error load(const std::string& tokenizer_path);

-  Error encode(
-      const char* text,
-      int8_t bos,
-      int8_t eos,
-      int32_t* tokens,
-      int32_t* n_tokens);
+  Result<std::vector<uint64_t>>
+  encode(const std::string& input, int8_t bos, int8_t eos);

-  Result<const char*> decode(int prev_token, int token);
+  Result<std::string> decode(uint64_t prev_token, uint64_t token);

   // getters
   int32_t vocab_size() const {
     return vocab_size_;
   }

-  int32_t bos_tok() const {
+  uint64_t bos_tok() const {
     return bos_tok_;
   }

-  int32_t eos_tok() const {
+  uint64_t eos_tok() const {
     return eos_tok_;
   }

  private:
   bool initialized_;
   const int32_t vocab_size_;
-  int32_t bos_tok_, eos_tok_;
+  uint64_t bos_tok_, eos_tok_;
   std::unique_ptr<char*[]> vocab_;
   std::unique_ptr<float[]> vocab_scores_;
   std::unique_ptr<TokenIndex[]> sorted_vocab_;