Commit a91c122

Improve support for special tokens
1 parent b24c304 commit a91c122

4 files changed: 272 additions, 15 deletions

convert.py

Lines changed: 36 additions & 7 deletions
@@ -133,18 +133,20 @@ def make_tensors_list() -> List[str]:
 @dataclass
 class Params:
     n_vocab: int
+    n_vocab_sp: int
     n_embd: int
     n_mult: int
     n_head: int
     n_layer: int
     file_type: GGMLFileType
 
     @staticmethod
-    def guessed(model: 'LazyModel', file_type: GGMLFileType) -> 'Params':
+    def guessed(model: 'LazyModel', vocab: 'Vocab', file_type: GGMLFileType) -> 'Params':
         n_vocab, n_embd = model["tok_embeddings.weight"].shape
 
         return Params(
             n_vocab=n_vocab,
+            n_vocab_sp=vocab.vocab_special_size,
             n_embd=n_embd,
             n_mult=256,
             n_head=n_embd // 128,
@@ -154,7 +156,7 @@ def guessed(model: 'LazyModel', file_type: GGMLFileType) -> 'Params':
 
 
 class SentencePieceVocab:
-    def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path]) -> None:
+    def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path], fname_special_tokens: Optional[Path]) -> None:
         self.sentencepiece_tokenizer = SentencePieceProcessor(str(fname_tokenizer))
         added_tokens: Dict[str, int]
         if fname_added_tokens is not None:
@@ -172,13 +174,24 @@ def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path]) ->
         self.vocab_size: int = self.vocab_size_base + len(self.added_tokens_list)
         self.fname_tokenizer = fname_tokenizer
         self.fname_added_tokens = fname_added_tokens
+        special_tokens: Dict[str, Dict[str, Any]]
+        if fname_special_tokens is not None:
+            special_tokens = json.load(open(fname_special_tokens))
+        else:
+            special_tokens = {}
+        token_name_to_id = {"unk_token": self.sentencepiece_tokenizer.unk_id(), "bos_token": self.sentencepiece_tokenizer.bos_id(), "eos_token": self.sentencepiece_tokenizer.eos_id(), "pad_token": self.sentencepiece_tokenizer.pad_id()}
+        self.special_tokens_map = {token_name_to_id[token_name]: info["content"] if isinstance(info, dict) else info for token_name, info in special_tokens.items() if token_name in token_name_to_id and token_name_to_id[token_name] != -1}
+        self.vocab_special_size: int = len(self.added_tokens_list) + len(self.special_tokens_map)
 
     def sentencepiece_tokens(self) -> Iterable[Tuple[bytes, float]]:
         tokenizer = self.sentencepiece_tokenizer
+        special_tokens = [tokenizer.bos_id(), tokenizer.eos_id(), tokenizer.pad_id()]
         for i in range(tokenizer.vocab_size()):
             text: bytes
             if tokenizer.is_unknown(i):
-                text = " \u2047 ".encode("utf-8")
+                text = self.special_tokens_map.get(i, " \u2047 ").encode("utf-8")
+            elif i in special_tokens:
+                text = self.special_tokens_map.get(i, "").encode("utf-8")
             elif tokenizer.is_control(i):
                 text = b""
             elif tokenizer.is_byte(i):
@@ -201,18 +214,29 @@ def all_tokens(self) -> Iterable[Tuple[bytes, float]]:
         yield from self.sentencepiece_tokens()
         yield from self.added_tokens()
 
+    def all_special_tokens(self) -> Iterable[int]:
+        for token_id in self.special_tokens_map.keys():
+            yield token_id
+        for i in range(len(self.added_tokens_list)):
+            yield self.vocab_size_base + i
+
     def __repr__(self) -> str:
         return f"<SentencePieceVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
 
 
 class GGMLVocab:
     def __init__(self, tokens: List[Tuple[bytes, float]]):
         self.tokens = tokens
+        self.special_tokens = []
         self.vocab_size = len(tokens)
+        self.vocab_special_size = 0
 
     def all_tokens(self) -> Iterable[Tuple[bytes, float]]:
         return self.tokens
 
+    def all_special_tokens(self) -> Iterable[int]:
+        return self.special_tokens
+
     def __repr__(self) -> str:
         return f"<GGMLVocab with {self.vocab_size} tokens>"
 
@@ -923,8 +947,9 @@ def __init__(self, fname_out: Path) -> None:
     def write_file_header(self, params: Params) -> None:
         self.fout.write(b"ggjt"[::-1]) # magic
         values = [
-            1, # file version
+            4, # file version
             params.n_vocab,
+            params.n_vocab_sp,
             params.n_embd,
             params.n_mult,
             params.n_head,
@@ -946,11 +971,13 @@ def write_vocab(self, vocab: Vocab) -> None:
             self.fout.write(struct.pack("i", len(text)))
             self.fout.write(text)
             self.fout.write(struct.pack("f", score))
+        for token_id in vocab.all_special_tokens():
+            self.fout.write(struct.pack("i", token_id))
 
     @staticmethod
     def write_vocab_only(fname_out: Path, vocab: Vocab) -> None:
         of = OutputFile(fname_out)
-        params = Params(n_vocab=vocab.vocab_size, n_embd=0, n_mult=0,
+        params = Params(n_vocab=vocab.vocab_size, n_vocab_sp=vocab.vocab_special_size, n_embd=0, n_mult=0,
                         n_head=1, n_layer=0, file_type=GGMLFileType.AllF32)
         of = OutputFile(fname_out)
         of.write_file_header(params)
@@ -1103,8 +1130,10 @@ def load_vocab(path: Path) -> SentencePieceVocab:
                 f"Could not find tokenizer.model in {path} or its parent; "
                 "if it's in another directory, pass the directory as --vocab-dir")
     added_tokens_path = path.parent / "added_tokens.json"
+    special_tokens_path = path.parent / "special_tokens_map.json"
+    tokenizer_config_path = path.parent / "tokenizer_config.json"
     print(f"Loading vocab file {path}")
-    return SentencePieceVocab(path, added_tokens_path if added_tokens_path.exists() else None)
+    return SentencePieceVocab(path, added_tokens_path if added_tokens_path.exists() else None, special_tokens_path if special_tokens_path.exists() else tokenizer_config_path if tokenizer_config_path.exists() else None)
 
 
 def default_outfile(model_paths: List[Path], params: Params) -> Path:
@@ -1168,7 +1197,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
         model = do_necessary_conversions(model)
         output_type = pick_output_type(model, args.outtype)
         model = convert_to_output_type(model, output_type)
-        params = Params.guessed(model, output_type)
+        params = Params.guessed(model, vocab, output_type)
         outfile = args.outfile or default_outfile(model_plus.paths, params)
         OutputFile.write_all(outfile, params, model, vocab)
         print(f"Wrote {outfile}")

llama-util.h

Lines changed: 164 additions & 0 deletions
@@ -14,6 +14,8 @@
 
 #include <string>
 #include <vector>
+#include <map>
+#include <unordered_map>
 #include <stdexcept>
 
 #ifdef __has_include
@@ -487,4 +489,166 @@ struct llama_ctx_buffer {
 typedef llama_buffer llama_ctx_buffer;
 #endif
 
+struct llama_trie_node {
+    llama_trie_node(): is_terminator(false) {}
+
+    std::unordered_map<char, llama_trie_node*> children;
+    bool is_terminator;
+};
+
+// Trie in C++. Creates a Trie out of a list of words. The trie is used to split on multiple delimiters in one pass.
+// Ported from: https://github.com/huggingface/transformers/blob/ee88ae59940fd4b2c8fc119373143d7a1175c651/src/transformers/tokenization_utils.py#L52
+struct llama_trie {
+public:
+    llama_trie(): root_(new llama_trie_node()) {}
+
+    void add(const std::string & word) {
+        if (word.empty()) {
+            return;
+        }
+
+        llama_trie_node *ref = root_;
+        for (char c : word) {
+            if (ref->children.find(c) == ref->children.end()) {
+                ref->children[c] = new llama_trie_node();
+            }
+            ref = ref->children[c];
+        }
+        ref->is_terminator = true;
+    }
+
+    // Looks for the words added to the trie within `text`. Output is the boundaries of the words found.
+    // Note that this trie will match the longest possible word first!
+    std::vector<int> split(const std::string & text) const {
+        std::map<int, llama_trie_node*> states;
+        std::vector<int> offsets{0};
+
+        int skip = 0;
+        for (int current = 0; current < text.size(); current++) {
+            char current_char = text[current];
+            if (skip > 0 && current < skip) {
+                // Prevents the lookahead from matching twice,
+                // e.g. extra_id_100 and id_100
+                continue;
+            }
+
+            // Whenever we find a match, we need to drop everything;
+            // this is a greedy algorithm: it will match on the first found token
+            bool reset = false;
+
+            // In this case, we already have partial matches (but unfinished)
+            for (auto state = states.begin(); state != states.end(); ) {
+                int start = state->first;
+                llama_trie_node *trie_pointer = state->second;
+                if (trie_pointer->is_terminator) {
+                    // This is a final match: we need to reset and
+                    // store the results in `offsets`.
+
+                    // Lookahead to match the longest word first.
+                    // Important in case of extra_id_1 vs extra_id_100.
+                    // Here we are also actively looking for other earlier partial
+                    // matches:
+                    // "[CLS]", "L", we need to match CLS even if L is special
+                    int end = 0;
+                    for (const auto & look : states) {
+                        int lookstart = look.first;
+                        llama_trie_node *looktrie_pointer = look.second;
+                        int lookahead_index = 0;
+                        if (lookstart > start) {
+                            // This partial match is later, we can stop looking
+                            break;
+                        }
+                        if (lookstart < start) {
+                            // This partial match is earlier; the trie pointer
+                            // was already updated, so the index is + 1
+                            lookahead_index = current + 1;
+                            end = current + 1;
+                        } else {
+                            // Here lookstart == start and
+                            // looktrie_pointer == trie_pointer.
+                            // It wasn't updated yet, so the indices are the current ones
+                            lookahead_index = current;
+                            end = current;
+                        }
+                        char next_char = lookahead_index < text.size() ? text[lookahead_index] : '\0';
+                        if (looktrie_pointer->is_terminator) {
+                            start = lookstart;
+                            end = lookahead_index;
+                            skip = lookahead_index;
+                        }
+
+                        auto looktrie_pointer_it = looktrie_pointer->children.find(next_char);
+                        while (looktrie_pointer_it != looktrie_pointer->children.end()) {
+                            looktrie_pointer = looktrie_pointer_it->second;
+                            lookahead_index++;
+                            if (looktrie_pointer->is_terminator) {
+                                start = lookstart;
+                                end = lookahead_index;
+                                skip = lookahead_index;
+                            }
+
+                            if (lookahead_index == text.size()) {
+                                // End of string
+                                break;
+                            }
+                            next_char = text[lookahead_index];
+                            looktrie_pointer_it = looktrie_pointer->children.find(next_char);
+                        }
+                    }
+
+                    offsets.push_back(start);
+                    offsets.push_back(end);
+                    reset = true;
+                    break;
+                }
+
+                auto trie_pointer_it = trie_pointer->children.find(current_char);
+                if (trie_pointer_it != trie_pointer->children.end()) {
+                    // The current character being looked at has a match within the trie:
+                    // update the pointer (it will be stored back into states later).
+                    trie_pointer = trie_pointer_it->second;
+                    states[start] = trie_pointer;
+                    ++state;
+                } else {
+                    // The new character has no match in the trie, so we need
+                    // to stop keeping track of this partial match.
+                    state = states.erase(state);
+                }
+            }
+
+            if (reset) {
+                // Clear all partial matches (we found a real match)
+                states.clear();
+            }
+
+            // If this character is a starting character within the trie,
+            // start keeping track of this partial match.
+            auto children_it = root_->children.find(current_char);
+            if (current >= skip && children_it != root_->children.end()) {
+                states[current] = children_it->second;
+            }
+        }
+
+        // We have a cut at the end with states.
+        for (const auto & state : states) {
+            int start = state.first;
+            llama_trie_node *trie_pointer = state.second;
+            if (trie_pointer->is_terminator) {
+                // This is a final match: we need to
+                // store the results in `offsets`.
+                int end = text.size();
+                offsets.push_back(start);
+                offsets.push_back(end);
+                break;
+            }
+        }
+
+        offsets.push_back(text.size());
+        return offsets;
+    }
+
+private:
+    llama_trie_node *root_;
+};
+
 #endif
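
The offsets vector returned by split() is a list of cut points: it starts at 0, ends at text.size(), and contains the start/end of every matched word in between, so consecutive pairs delimit alternating plain-text and matched pieces (an empty piece appears wherever two cut points coincide). A small standalone usage sketch, not part of the commit, assuming llama-util.h is on the include path and using made-up special-token strings:

#include <cstdio>
#include <string>
#include <vector>

#include "llama-util.h"

int main() {
    llama_trie trie;
    // Hypothetical special-token strings; in practice these would come from the
    // special tokens stored alongside the model vocabulary.
    trie.add("[CLS]");
    trie.add("[SEP]");

    const std::string text = "[CLS] hello world [SEP]";
    std::vector<int> offsets = trie.split(text);

    for (size_t i = 0; i + 1 < offsets.size(); i++) {
        int start = offsets[i];
        int end   = offsets[i + 1];
        if (start == end) {
            continue; // adjacent cut points yield an empty piece
        }
        std::string piece = text.substr(start, end - start);
        // Pieces equal to an added word ("[CLS]", "[SEP]") can be mapped straight
        // to their token ids; everything else goes through normal tokenization.
        printf("piece: '%s'\n", piece.c_str());
    }
    return 0;
}

With the two words above this prints "[CLS]", " hello world ", and "[SEP]" as separate pieces, which is the single-pass splitting behaviour the ported HuggingFace trie provides.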
