@@ -12,98 +12,77 @@ namespace torch {
12
12
namespace executor {
13
13
namespace {
14
14
static constexpr int32_t kSpecialTokensSize = 256 ;
15
+ static std::string kBOSToken = " <|begin_of_text|>" ;
16
+ static constexpr size_t kBOSTokenIndex = 0 ;
17
+ static std::string kEOSToken = " <|end_of_text|>" ;
18
+ static constexpr size_t kEOSTokenIndex = 1 ;
15
19
16
- static inline const Encoder _get_default_special_tokens (
17
- ssize_t num_base_tokens) {
18
- Encoder special_tokens;
19
- ssize_t special_token_count = 0 ;
20
- special_tokens.emplace (
21
- " <|begin_of_text|>" , num_base_tokens + special_token_count++);
22
- special_tokens.emplace (
23
- " <|end_of_text|>" , num_base_tokens + special_token_count++);
24
- special_tokens.emplace (
25
- " <|reserved_special_token_0|>" , num_base_tokens + special_token_count++);
26
- special_tokens.emplace (
27
- " <|reserved_special_token_1|>" , num_base_tokens + special_token_count++);
28
- special_tokens.emplace (
29
- " <|reserved_special_token_2|>" , num_base_tokens + special_token_count++);
30
- special_tokens.emplace (
31
- " <|reserved_special_token_3|>" , num_base_tokens + special_token_count++);
32
- special_tokens.emplace (
33
- " <|start_header_id|>" , num_base_tokens + special_token_count++);
34
- special_tokens.emplace (
35
- " <|end_header_id|>" , num_base_tokens + special_token_count++);
36
- special_tokens.emplace (
37
- " <|reserved_special_token_4|>" , num_base_tokens + special_token_count++);
38
- special_tokens.emplace (" <|eot_id|>" , num_base_tokens + special_token_count++);
20
+ static inline std::unique_ptr<std::vector<std::string>>
21
+ _get_default_special_tokens () {
22
+ auto special_tokens = std::make_unique<std::vector<std::string>>(
23
+ std::vector<std::string>{kBOSToken , kEOSToken });
24
+ special_tokens->emplace_back (" <|reserved_special_token_0|>" );
25
+ special_tokens->emplace_back (" <|reserved_special_token_1|>" );
26
+ special_tokens->emplace_back (" <|reserved_special_token_2|>" );
27
+ special_tokens->emplace_back (" <|reserved_special_token_3|>" );
28
+ special_tokens->emplace_back (" <|start_header_id|>" );
29
+ special_tokens->emplace_back (" <|end_header_id|>" );
30
+ special_tokens->emplace_back (" <|reserved_special_token_4|>" );
31
+ special_tokens->emplace_back (" <|eot_id|>" );
39
32
40
33
// pad the rest of the special tokens with reserved tokens
41
34
ssize_t reserved_special_token_num = 5 ;
42
- while (special_token_count < kSpecialTokensSize ) {
43
- special_tokens. emplace (
35
+ while (special_tokens-> size () < kSpecialTokensSize ) {
36
+ special_tokens-> emplace_back (
44
37
" <|reserved_special_token_" +
45
- std::to_string (reserved_special_token_num++) + " |>" ,
46
- num_base_tokens + special_token_count++);
38
+ std::to_string (reserved_special_token_num++) + " |>" );
47
39
}
48
40
return special_tokens;
49
41
}
50
42
51
- static inline const Encoder _get_multimodal_special_tokens (
52
- ssize_t num_base_tokens) {
53
- ssize_t special_token_count = 0 ;
54
- Encoder special_tokens;
55
- special_tokens.emplace (
56
- " <|begin_of_text|>" , num_base_tokens + special_token_count++);
57
- special_tokens.emplace (
58
- " <|end_of_text|>" , num_base_tokens + special_token_count++);
59
- special_tokens.emplace (
60
- " <|reserved_special_token_0|>" , num_base_tokens + special_token_count++);
61
- special_tokens.emplace (
62
- " <|reserved_special_token_1|>" , num_base_tokens + special_token_count++);
63
- special_tokens.emplace (
64
- " <|reserved_special_token_2|>" , num_base_tokens + special_token_count++);
65
- special_tokens.emplace (
66
- " <|reserved_special_token_3|>" , num_base_tokens + special_token_count++);
67
- special_tokens.emplace (
68
- " <|start_header_id|>" , num_base_tokens + special_token_count++);
69
- special_tokens.emplace (
70
- " <|end_header_id|>" , num_base_tokens + special_token_count++);
71
- special_tokens.emplace (" <|eom_id|>" , num_base_tokens + special_token_count++);
72
- special_tokens.emplace (" <|eot_id|>" , num_base_tokens + special_token_count++);
73
- special_tokens.emplace (" <|image|>" , num_base_tokens + special_token_count++);
43
+ static inline std::unique_ptr<std::vector<std::string>>
44
+ _get_multimodal_special_tokens () {
45
+ auto special_tokens = std::make_unique<std::vector<std::string>>(
46
+ std::vector<std::string>{kBOSToken , kEOSToken });
47
+ special_tokens->emplace_back (" <|reserved_special_token_0|>" );
48
+ special_tokens->emplace_back (" <|reserved_special_token_1|>" );
49
+ special_tokens->emplace_back (" <|reserved_special_token_2|>" );
50
+ special_tokens->emplace_back (" <|reserved_special_token_3|>" );
51
+ special_tokens->emplace_back (" <|start_header_id|>" );
52
+ special_tokens->emplace_back (" <|end_header_id|>" );
53
+ special_tokens->emplace_back (" <|eom_id|>" );
54
+ special_tokens->emplace_back (" <|eot_id|>" );
55
+ special_tokens->emplace_back (" <|image|>" );
74
56
75
57
// pad the rest of the special tokens with reserved tokens except the last
76
58
// one
77
59
ssize_t reserved_special_token_num = 4 ;
78
- while (special_token_count < kSpecialTokensSize - 1 ) {
79
- special_tokens. emplace (
60
+ while (special_tokens-> size () < kSpecialTokensSize - 1 ) {
61
+ special_tokens-> emplace_back (
80
62
" <|reserved_special_token_" +
81
- std::to_string (reserved_special_token_num++) + " |>" ,
82
- num_base_tokens + special_token_count++);
63
+ std::to_string (reserved_special_token_num++) + " |>" );
83
64
}
84
65
85
- special_tokens.emplace (
86
- " <|python_tag|>" , num_base_tokens + special_token_count++);
66
+ special_tokens->emplace_back (" <|python_tag|>" );
87
67
88
68
return special_tokens;
89
69
}
90
- } // namespace
91
70
92
- const Encoder LlamaTiktoken::get_special_tokens ( ssize_t num_base_tokens) const {
93
- switch (_version ) {
71
+ std::unique_ptr<std::vector<std::string>> _get_special_tokens (Version version) {
72
+ switch (version ) {
94
73
case MULTIMODAL:
95
- return _get_multimodal_special_tokens (num_base_tokens );
74
+ return _get_multimodal_special_tokens ();
96
75
default :
97
- return _get_default_special_tokens (num_base_tokens );
76
+ return _get_default_special_tokens ();
98
77
}
99
78
}
100
79
101
- const std::string LlamaTiktoken::get_bos_token () const {
102
- return " <|begin_of_text|>" ;
103
- }
80
+ } // namespace
104
81
105
- const std::string LlamaTiktoken::get_eos_token () const {
106
- return " <|end_of_text|>" ;
82
+ std::unique_ptr<Tiktoken> get_tiktoken_for_llama (Version version) {
83
+ return std::make_unique<Tiktoken>(
84
+ _get_special_tokens (version), kBOSTokenIndex , kEOSTokenIndex );
107
85
}
86
+
108
87
} // namespace executor
109
88
} // namespace torch
0 commit comments