/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

/**
 * This is a simple tool to instantiate a tokenizer and run it over some text.
 * It can be used to evaluate the tokenization done by a given tokenizer model
 * relative to its native python library.
 */
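//
// Example invocation (the binary name and model path below are illustrative,
// not mandated by this file; adjust them to your build):
//
//   ./tokenize_tool tiktoken /path/to/tokenizer.model Hello world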

// Standard
#include <iostream>
#include <memory>
#include <sstream>

// Local
#include "hf_tokenizer.h"
#include "sentencepiece.h"
#include "tiktoken.h"

using namespace tokenizers;

std::string help(char* argv[]) {
  std::stringstream ss;
  ss << "Usage: " << argv[0] << " <type> <model> <input to tokenize...>" << std::endl << std::endl;
  ss << "Types:\n" << std::endl;
  ss << "* sentencepiece: SPTokenizer" << std::endl;
  ss << "* tiktoken: Tiktoken" << std::endl;
  ss << "* hf_tokenizer: HFTokenizer" << std::endl;
  return ss.str();
}

int main(int argc, char* argv[]) {

  // Check for the right number of CLI args
  if (argc < 4) {
    std::cerr << help(argv) << std::endl;
    return 1;
  }

  // Parse CLI args
  const std::string tokenizer_type(argv[1]);
  const std::string model_path(argv[2]);
  std::stringstream prompt_ss;
  for (auto i = 3; i < argc; ++i) {
    if (i > 3) {
      prompt_ss << " ";
    }
    prompt_ss << argv[i];
  }
  const std::string prompt = prompt_ss.str();

  // Instantiate the tokenizer
  std::unique_ptr<Tokenizer> tok_ptr;
  if (tokenizer_type == "sentencepiece") {
    tok_ptr.reset(new SPTokenizer());
  } else if (tokenizer_type == "tiktoken") {
    tok_ptr.reset(new Tiktoken());
  } else if (tokenizer_type == "hf_tokenizer") {
    tok_ptr.reset(new HFTokenizer());
  } else {
    std::stringstream ss;
    ss << "ERROR: Invalid tokenizer type: " << tokenizer_type << std::endl << std::endl;
    ss << help(argv);
    std::cerr << ss.str() << std::endl;
    return 1;
  }

  // Load from the path
  tok_ptr->load(model_path);

  // Log out the IDs for the BOS/EOS tokens
  std::cout << "Vocab Size: " << tok_ptr->vocab_size() << std::endl;
  std::cout << "BOS: " << tok_ptr->bos_tok() << std::endl;
  std::cout << "EOS: " << tok_ptr->eos_tok() << std::endl << std::endl;

  // Encode
  std::cout << "PROMPT:" << std::endl << prompt << std::endl << std::endl;
  std::cout << "Encoding..." << std::endl;
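  // NOTE: the two trailing arguments are (as assumed here) the number of
  // BOS/EOS tokens to add around the prompt; passing 0/0 keeps the printed IDs
  // limited to the prompt text itself, for comparison against python output.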
  const auto encoded_result = tok_ptr->encode(prompt, 0, 0);
  const auto encoded = encoded_result.get();
  std::cout << "[";
  for (const auto tok_id : encoded) {
    std::cout << " " << tok_id;
  }
  std::cout << "]" << std::endl << std::endl;

  // Decode
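  // (decode() is called pairwise on (previous, current) token IDs; prev is
  // seeded with BOS so sentencepiece-style tokenizers can decide whether to
  // emit a leading space for the first token.)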
  std::cout << "Decoding..." << std::endl;
  uint64_t prev = tok_ptr->bos_tok();
  for (const auto& current : encoded) {
    const auto decoded_result = tok_ptr->decode(prev, current);
    std::cout << decoded_result.get();
    prev = current;
  }
  std::cout << std::endl;

  return 0;
}