@@ -275,41 +275,57 @@ std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::stri
     return tokens;
 }

+// TODO: Calculate this constant from the vocabulary
+#define MAX_TOKEN_LEN 18
+// SentencePiece implementation after https://guillaume-be.github.io/2020-05-30/sentence_piece
 std::vector<gpt_vocab::id> llama_tokenize(const gpt_vocab & vocab, const std::string & text, bool bos) {
-    // auto res = gpt_tokenize(vocab, text);
-
-    // if (bos) {
-    //     res.insert(res.begin(), 1); // TODO: replace with vocab.bos
-    // }
-
     std::vector<gpt_vocab::id> res;
-
-    if (bos) {
-        res.push_back(1); // TODO: replace with vocab.bos
-    }
-
-    // find the longest token that matches the text
-    int pos = 0;
-    while (true) {
-        int l = 0;
-        int t = 0;
-        for (const auto & kv : vocab.id_to_token) {
-            if (kv.second.size() < l) continue;
-            if (kv.second.size() > text.size() - pos) continue;
-            if (text.substr(pos, kv.second.size()) == kv.second) {
-                l = kv.second.size();
-                t = kv.first;
+    std::vector<int> score;
+    std::vector<gpt_vocab::id> prev;
+    int len = text.length();
+
+    score.resize(len + 1);
+    prev.resize(len + 1);
+
+    // Forward pass
+    for (int i = 0; i < len; i++) {
+        int max_len = std::min(len - i, MAX_TOKEN_LEN);
+        for (int sub_len = 1; sub_len <= max_len; sub_len++) {
+            auto sub = text.substr(i, sub_len);
+            auto token = vocab.token_to_id.find(sub);
+            if (token != vocab.token_to_id.end()) {
+                int token_score = sub.length() * sub.length();
+                int local_score = score[i] + token_score;
+                int next = i + sub_len;
+                if (score[next] < local_score) {
+                    score[next] = local_score;
+                    prev[next] = (*token).second;
+                }
             }
         }
+    }

-        if (l == 0) {
-            break;
+    // Backward pass
+    int i = len;
+    while (i > 0) {
+        gpt_vocab::id token_id = prev[i];
+        if (token_id == 0) {
+            // TODO: Return error or something more meaningful
+            printf("failed to tokenize string!\n");
+            break;
         }
+        res.push_back(token_id);
+        auto token = (*vocab.id_to_token.find(token_id)).second;
+        i -= token.length();
+    }

-        res.push_back(t);
-        pos += l;
+    if (bos) {
+        res.push_back(1); // TODO: replace with vocab.bos
     }

+    // Pieces are in reverse order so correct that
+    std::reverse(res.begin(), res.end());
+
     return res;
 }
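For reference, the forward/backward dynamic programming in this commit can be exercised in isolation. The sketch below reimplements the same idea over a toy vocabulary; the vocabulary, the `max_token_len` bound, and the `main` harness are made up for illustration and are not part of the commit. Scoring each piece as its squared length makes the DP prefer fewer, longer pieces: taking a 4-character piece whole scores 16, while splitting it into two 2-character pieces scores only 8.

```cpp
#include <algorithm>
#include <cstdio>
#include <map>
#include <string>
#include <vector>

int main() {
    // Hypothetical toy vocabulary; id 0 is reserved so that prev[i] == 0 can
    // mean "no known token ends at position i".
    std::map<std::string, int> token_to_id = {
        {"h", 1}, {"e", 2}, {"l", 3}, {"o", 4},
        {"he", 5}, {"ll", 6}, {"hell", 7}, {"o ", 8}, {"world", 9},
    };
    std::map<int, std::string> id_to_token;
    for (const auto & kv : token_to_id) id_to_token[kv.second] = kv.first;

    const std::string text = "hello world";
    const int len = text.length();
    const int max_token_len = 5; // length of the longest piece in this toy vocabulary

    // Forward pass: score[n] is the best score over all segmentations of the
    // first n characters; prev[n] is the last piece of that best segmentation.
    std::vector<int> score(len + 1, 0);
    std::vector<int> prev(len + 1, 0);
    for (int i = 0; i < len; i++) {
        const int max_len = std::min(len - i, max_token_len);
        for (int sub_len = 1; sub_len <= max_len; sub_len++) {
            const auto it = token_to_id.find(text.substr(i, sub_len));
            if (it == token_to_id.end()) continue;
            const int local_score = score[i] + sub_len * sub_len;
            if (score[i + sub_len] < local_score) {
                score[i + sub_len] = local_score;
                prev[i + sub_len] = it->second;
            }
        }
    }

    // Backward pass: walk from the end of the string to the front,
    // collecting the winning pieces in reverse order.
    std::vector<int> res;
    for (int i = len; i > 0; ) {
        const int token_id = prev[i];
        if (token_id == 0) {
            printf("failed to tokenize string!\n");
            return 1;
        }
        res.push_back(token_id);
        i -= id_to_token[token_id].length();
    }
    std::reverse(res.begin(), res.end());

    for (const int id : res) printf("%d:'%s' ", id, id_to_token[id].c_str());
    printf("\n"); // prints: 7:'hell' 8:'o ' 9:'world'
    return 0;
}
```

The squared-length score is a cheap stand-in for the per-piece log-probabilities a real SentencePiece unigram model would maximize; the blog post linked in the code comments describes that more general formulation.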