@@ -23,7 +23,7 @@ static int compare_tokens(const void* a, const void* b) {
   return strcmp(((TokenIndex*)a)->str, ((TokenIndex*)b)->str);
 }
 
-Tokenizer::Tokenizer(int32_t vocab_size, int32_t bos_tok, int32_t eos_tok)
+Tokenizer::Tokenizer(int32_t vocab_size, uint64_t bos_tok, uint64_t eos_tok)
     : initialized_(false),
       vocab_size_(vocab_size),
       bos_tok_(bos_tok),
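
The constructor now takes the special-token ids as uint64_t while vocab_size stays int32_t, so the matching member declarations have to widen as well. The header is not part of this diff; a minimal sketch of what it presumably looks like after the change, with member names taken from the initializer list above:

```cpp
// Hypothetical header excerpt mirroring this hunk; the real tokenizer.h
// may differ. Widening to uint64_t lets token ids exceed INT32_MAX.
class Tokenizer {
 public:
  Tokenizer(int32_t vocab_size, uint64_t bos_tok, uint64_t eos_tok);

 private:
  bool initialized_;
  int32_t vocab_size_;
  uint64_t bos_tok_; // was int32_t
  uint64_t eos_tok_; // was int32_t
};
```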
@@ -142,10 +142,10 @@ Tokenizer::~Tokenizer() {
  *
  * @param prev_token The previous token.
  * @param token The current token.
- * @return Result<const char*> A pointer to the string representation of the
+ * @return Result<std::string> The string representation of the
  * token.
  */
-Result<const char*> Tokenizer::decode(int32_t prev_token, int32_t token) {
+Result<std::string> Tokenizer::decode(uint64_t prev_token, uint64_t token) {
   if (!initialized_) {
     ET_LOG(Error, "Tokenizer not initialized");
     return Error::NotSupported;
@@ -162,7 +162,8 @@ Result<const char*> Tokenizer::decode(int32_t prev_token, int32_t token) {
   if (sscanf(piece, "<0x%02hhX>", &byte_val) == 1) {
     piece = (char*)byte_pieces_ + byte_val * 2;
   }
-  return piece;
+  std::string res(piece);
+  return res;
 }
 
 static int32_t
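
Returning Result<std::string> means callers receive an owned copy instead of a pointer into the tokenizer's internal piece storage, so the text stays valid across subsequent decode calls. A minimal caller sketch, assuming the Result type exposes ok(), get(), and error() as ExecuTorch's runtime Result does:

```cpp
// Hedged usage sketch; `tok`, `prev`, and `cur` are illustrative names.
Result<std::string> piece = tok.decode(prev, cur);
if (!piece.ok()) {
  return piece.error(); // e.g. Error::NotSupported if not initialized
}
printf("%s", piece.get().c_str()); // owned copy, safe to keep around
```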
@@ -183,23 +184,19 @@ str_lookup(const char* str, TokenIndex* sorted_vocab, int32_t vocab_size) {
  * @param eos The number of EOS to append to the token list.
  * @param tokens The output tokens.
  * @param n_tokens The number of tokens.
- * @return Error
+ * @return Result<std::vector<uint64_t>>
  */
-Error Tokenizer::encode(
-    const char* text,
-    int8_t bos,
-    int8_t eos,
-    int32_t* tokens,
-    int32_t* n_tokens) {
+Result<std::vector<uint64_t>>
+Tokenizer::encode(const std::string& text, int8_t bos, int8_t eos) {
   if (!initialized_) {
     ET_LOG(Error, "Tokenizer not initialized");
     return Error::NotSupported;
   }
   // encode the string text (input) into an upper-bound preallocated tokens[]
   // array bos != 0 means prepend the BOS token (=1), eos != 0 means append the
   // EOS token (=2)
-  if (text == nullptr) {
-    ET_LOG(Error, "cannot encode null text");
+  if (text.empty()) {
+    ET_LOG(Error, "cannot encode empty text");
     return Error::InvalidArgument;
   }
 
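The new signature drops the caller-managed tokens/n_tokens output parameters in favor of a returned vector, so callers no longer have to guess an upper bound for the output buffer. A before/after sketch of a call site (names illustrative, error handling per the Result conventions above):

```cpp
// Before: preallocate a worst-case buffer and pass out-parameters.
//   std::vector<int32_t> buf(text_len + 3); // rough bound: +BOS/EOS/prefix
//   int32_t n_tokens = 0;
//   Error err = tok.encode(text, /*bos=*/1, /*eos=*/0, buf.data(), &n_tokens);

// After: the tokenizer allocates; failures still surface through Result.
Result<std::vector<uint64_t>> res = tok.encode(text, /*bos=*/1, /*eos=*/0);
if (!res.ok()) {
  return res.error();
}
const std::vector<uint64_t>& tokens = res.get();
```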
@@ -210,12 +207,12 @@ Error Tokenizer::encode(
   size_t str_len = 0;
 
   // start at 0 tokens
-  *n_tokens = 0;
+  std::vector<uint64_t> tokens;
 
   // add optional BOS token, if desired
   if (bos > 0) {
     while (bos--) {
-      tokens[(*n_tokens)++] = bos_tok_;
+      tokens.push_back(bos_tok_);
     }
   } else {
     ET_LOG(Error, "bos %d should be >= 0", bos);
@@ -230,7 +227,7 @@ Error Tokenizer::encode(
   const char* space = " ";
   if (text[0] != '\0') {
     int dummy_prefix = str_lookup(space, sorted_vocab_.get(), vocab_size_);
-    tokens[(*n_tokens)++] = dummy_prefix;
+    tokens.push_back(dummy_prefix);
   }
 
   // Okay UTF-8 time. This will get messy. Here is the reference from Wikipedia:
@@ -242,7 +239,7 @@ Error Tokenizer::encode(
   // U+10000 U+10FFFF 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
 
   // process the raw (UTF-8) byte sequence of the input string
-  for (const char* c = text; *c != '\0'; c++) {
+  for (const char* c = text.c_str(); *c != '\0'; c++) {
     // reset buffer if the current byte is ASCII or a leading byte
     // 0xC0 is 11000000, so (*c & 0xC0) keeps the first 2 bits and zeros the
     // rest 0x80 is 10000000 in UTF-8, all continuation bytes start with "10" in
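
The loop resets its accumulation buffer only on ASCII or leading bytes, relying on the mask test spelled out in the comments. The same test, restated in isolation:

```cpp
// Standalone restatement of the bit test used above: a byte is a UTF-8
// continuation byte exactly when its top two bits are "10".
static inline bool is_continuation_byte(unsigned char c) {
  return (c & 0xC0) == 0x80; // 0xC0 = 11000000 masks the top two bits
}
```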
@@ -271,13 +268,13 @@ Error Tokenizer::encode(
     int id = str_lookup(str_buffer, sorted_vocab_.get(), vocab_size_);
     if (id != -1) {
       // we found this codepoint in vocab, add it as a token
-      tokens[(*n_tokens)++] = id;
+      tokens.push_back(id);
     } else {
       // byte_fallback encoding: just encode each byte as a token
       // +3 is here because the first 3 vocab elements are <unk>, <s>, </s>
       // so the individual bytes only start at index 3
       for (int i = 0; i < str_len; i++) {
-        tokens[(*n_tokens)++] = (unsigned char)str_buffer[i] + 3;
+        tokens.push_back((unsigned char)str_buffer[i] + 3);
       }
     }
     str_len = 0; // protect against a sequence of stray UTF8 continuation bytes
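
When a codepoint is missing from the vocab, each of its raw bytes becomes a token at id byte + 3, skipping the three reserved entries. A worked example:

```cpp
// Worked example (comments only, no new API assumed): U+20AC '€' is the
// 3-byte UTF-8 sequence 0xE2 0x82 0xAC. If that codepoint is not in the
// vocab, each raw byte becomes its own token, offset by 3 to skip
// <unk>=0, <s>=1, </s>=2:
//   0xE2 + 3 = 229,  0x82 + 3 = 133,  0xAC + 3 = 175
```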
@@ -290,7 +287,7 @@ Error Tokenizer::encode(
     int best_id = -1;
     int best_idx = -1;
 
-    for (int i = 0; i < (*n_tokens - 1); i++) {
+    for (int i = 0; i < (int)tokens.size() - 1; i++) {
       // check if we can merge the pair (tokens[i], tokens[i+1])
       snprintf(
           str_buffer,
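
tokens.size() returns an unsigned size_t, so the explicit signed cast in the new loop bound matters: with an empty token vector, size() - 1 would wrap around to SIZE_MAX and the unsigned comparison would drive the loop far out of bounds. A one-line demonstration of the wraparound:

```cpp
#include <cstdint>
// size_t subtraction wraps: 0 - 1 == SIZE_MAX, so `i < size() - 1` would
// hold for every i when the container is empty. Casting to int makes the
// empty case evaluate as -1, and `0 < -1` correctly skips the loop.
static_assert(static_cast<std::size_t>(0) - 1 == SIZE_MAX,
              "unsigned subtraction wraps around");
```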
@@ -314,24 +311,24 @@ Error Tokenizer::encode(
     // merge the consecutive pair (best_idx, best_idx+1) into new token best_id
     tokens[best_idx] = best_id;
     // delete token at position best_idx+1, shift the entire sequence back 1
-    for (int i = best_idx + 1; i < (*n_tokens - 1); i++) {
+    for (int i = best_idx + 1; i < (int)tokens.size() - 1; i++) {
       tokens[i] = tokens[i + 1];
     }
-    (*n_tokens)--; // token length decreased
+    tokens.pop_back(); // token length decreased
   }
 
   // add optional EOS (=2) token, if desired
   if (eos >= 0) {
     while (eos--) {
-      tokens[(*n_tokens)++] = eos_tok_;
+      tokens.push_back(eos_tok_);
     }
   } else {
     ET_LOG(Error, "eos %d should be >= 0", eos);
     return Error::InvalidArgument;
   }
 
   delete[] str_buffer;
-  return Error::Ok;
+  return Result(tokens);
 }
 
 } // namespace executor
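
Putting the two reworked entry points together, an end-to-end round trip under the new API might look like this. This is a hedged sketch: `tok` is a loaded tokenizer, the BOS id of 1 follows the "prepend the BOS token (=1)" comment in encode(), and Result is assumed to expose ok()/get() as before:

```cpp
// Encode a string, then stream the decoded pieces back out.
Result<std::vector<uint64_t>> enc =
    tok.encode("hello world", /*bos=*/1, /*eos=*/0);
if (enc.ok()) {
  uint64_t prev = 1; // BOS, per the "(=1)" comment in encode()
  for (uint64_t id : enc.get()) {
    Result<std::string> piece = tok.decode(prev, id);
    if (piece.ok()) {
      fputs(piece.get().c_str(), stdout);
    }
    prev = id;
  }
}
```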