
Commit 61a98bc

Improve support for special tokens

1 parent 93356bd

4 files changed: +297 -36 lines changed


convert.py

Lines changed: 61 additions & 28 deletions
@@ -142,6 +142,7 @@ def find_n_mult(n_ff: int, n_embd: int) -> int:
 @dataclass
 class Params:
     n_vocab: int
+    n_vocab_sp: int
     n_embd: int
     n_mult: int
     n_head: int
@@ -169,6 +170,7 @@ def guessed(model: 'LazyModel') -> 'Params':
 
         return Params(
             n_vocab = n_vocab,
+            n_vocab_sp = n_vocab,
             n_embd = n_embd,
             n_mult = 256,
             n_head = n_head,
@@ -191,6 +193,7 @@ def loadHFTransformerJson(model: 'LazyModel', config_path: 'Path') -> 'Params':
 
         return Params(
             n_vocab = n_vocab,
+            n_vocab_sp = n_vocab,
             n_embd = n_embd,
             n_mult = n_mult,
             n_head = n_head,
@@ -215,6 +218,7 @@ def loadOriginalParamsJson(model: 'LazyModel', config_path: 'Path') -> 'Params':
 
         return Params(
             n_vocab = n_vocab,
+            n_vocab_sp = n_vocab,
             n_embd = n_embd,
             n_mult = n_mult,
             n_head = n_head,
@@ -239,7 +243,7 @@ def load(model_plus: 'ModelPlus') -> 'Params':
 
 
 class SentencePieceVocab:
-    def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path], vocabtype: Optional[str]) -> None:
+    def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path], fname_special_tokens: Optional[Path], vocabtype: Optional[str]) -> None:
         self.vocabtype = vocabtype
         if self.vocabtype == "bpe":
             self.sentencepiece_tokenizer = json.loads(open(str(fname_tokenizer)).read())
@@ -264,35 +268,46 @@ def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path], vo
         self.vocab_size: int = self.vocab_size_base + len(self.added_tokens_list)
         self.fname_tokenizer = fname_tokenizer
         self.fname_added_tokens = fname_added_tokens
+        special_tokens: Dict[str, Dict[str, Any]]
+        if fname_special_tokens is not None:
+            special_tokens = json.load(open(fname_special_tokens))
+        else:
+            special_tokens = {}
+        token_name_to_id = {"unk_token": self.sentencepiece_tokenizer.unk_id(), "bos_token": self.sentencepiece_tokenizer.bos_id(), "eos_token": self.sentencepiece_tokenizer.eos_id(), "pad_token": self.sentencepiece_tokenizer.pad_id()}
+        self.special_tokens_map = {token_name_to_id[token_name]: info["content"] if isinstance(info, dict) else info for token_name, info in special_tokens.items() if token_name in token_name_to_id and token_name_to_id[token_name] != -1}
+        self.vocab_special_size: int = len(self.added_tokens_list) + len(self.special_tokens_map)
 
     def sentencepiece_tokens(self) -> Iterable[Tuple[bytes, float]]:
         tokenizer = self.sentencepiece_tokenizer
         if self.vocabtype == "bpe":
-          from transformers.models.gpt2 import tokenization_gpt2
-          byte_encoder = tokenization_gpt2.bytes_to_unicode()
-          byte_decoder = {v: k for k, v in byte_encoder.items()}
-          for i, item in enumerate(tokenizer):
-            text: bytes
-            text = b''.join([x.to_bytes(1, byteorder='big') for x in [byte_decoder[y] for y in item]])
-            score: float = -i
-            yield text, score
+            from transformers.models.gpt2 import tokenization_gpt2
+            byte_encoder = tokenization_gpt2.bytes_to_unicode()
+            byte_decoder = {v: k for k, v in byte_encoder.items()}
+            for i, item in enumerate(tokenizer):
+                text: bytes
+                text = b''.join([x.to_bytes(1, byteorder='big') for x in [byte_decoder[y] for y in item]])
+                score: float = -i
+                yield text, score
         else:
-          for i in range(tokenizer.vocab_size()):
-            text: bytes
-            if tokenizer.is_unknown(i):
-              text = " \u2047 ".encode("utf-8")
-            elif tokenizer.is_control(i):
-              text = b""
-            elif tokenizer.is_byte(i):
-              piece = tokenizer.id_to_piece(i)
-              if len(piece) != 6:
-                raise Exception(f"Invalid token: {piece}")
-              byte_value = int(piece[3:-1], 16)
-              text = struct.pack("B", byte_value)
-            else:
-              text = tokenizer.id_to_piece(i).replace("\u2581", " ").encode("utf-8")
-            score: float = tokenizer.get_score(i)
-            yield text, score
+            special_tokens = [tokenizer.bos_id(), tokenizer.eos_id(), tokenizer.pad_id()]
+            for i in range(tokenizer.vocab_size()):
+                text: bytes
+                if tokenizer.is_unknown(i):
+                    text = self.special_tokens_map.get(i, " \u2047 ").encode("utf-8")
+                elif i in special_tokens:
+                    text = self.special_tokens_map.get(i, "").encode("utf-8")
+                elif tokenizer.is_control(i):
+                    text = b""
+                elif tokenizer.is_byte(i):
+                    piece = tokenizer.id_to_piece(i)
+                    if len(piece) != 6:
+                        raise Exception(f"Invalid token: {piece}")
+                    byte_value = int(piece[3:-1], 16)
+                    text = struct.pack("B", byte_value)
+                else:
+                    text = tokenizer.id_to_piece(i).replace("\u2581", " ").encode("utf-8")
+                score: float = tokenizer.get_score(i)
+                yield text, score
 
     def added_tokens(self) -> Iterable[Tuple[bytes, float]]:
         for text in self.added_tokens_list:
@@ -303,18 +318,29 @@ def all_tokens(self) -> Iterable[Tuple[bytes, float]]:
         yield from self.sentencepiece_tokens()
         yield from self.added_tokens()
 
+    def all_special_tokens(self) -> Iterable[int]:
+        for token_id in self.special_tokens_map.keys():
+            yield token_id
+        for i in range(len(self.added_tokens_list)):
+            yield self.vocab_size_base + i
+
     def __repr__(self) -> str:
         return f"<SentencePieceVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
 
 
 class GGMLVocab:
     def __init__(self, tokens: List[Tuple[bytes, float]]):
         self.tokens = tokens
+        self.special_tokens = []
         self.vocab_size = len(tokens)
+        self.vocab_special_size = 0
 
     def all_tokens(self) -> Iterable[Tuple[bytes, float]]:
         return self.tokens
 
+    def all_special_tokens(self) -> Iterable[int]:
+        return self.special_tokens
+
     def __repr__(self) -> str:
         return f"<GGMLVocab with {self.vocab_size} tokens>"
 
@@ -1066,8 +1092,9 @@ def __init__(self, fname_out: Path) -> None:
     def write_file_header(self, params: Params, file_type: GGMLFileType) -> None:
         self.fout.write(b"ggjt"[::-1]) # magic
         values = [
-            1, # file version
+            4, # file version
             params.n_vocab,
+            params.n_vocab_sp,
             params.n_embd,
             params.n_mult,
             params.n_head,
@@ -1089,11 +1116,14 @@ def write_vocab(self, vocab: Vocab) -> None:
             self.fout.write(struct.pack("i", len(text)))
             self.fout.write(text)
             self.fout.write(struct.pack("f", score))
+        for token_id in vocab.all_special_tokens():
+            self.fout.write(struct.pack("i", token_id))
 
     @staticmethod
     def write_vocab_only(fname_out: Path, vocab: Vocab) -> None:
         of = OutputFile(fname_out)
-        params = Params(n_vocab=vocab.vocab_size, n_embd=0, n_mult=0, n_head=1, n_layer=0)
+        params = Params(n_vocab=vocab.vocab_size, n_vocab_sp=vocab.vocab_special_size, n_embd=0, n_mult=0,
+                        n_head=1, n_layer=0)
         of = OutputFile(fname_out)
         of.write_file_header(params, file_type=GGMLFileType.AllF32)
         of.write_vocab(vocab)
@@ -1249,8 +1279,10 @@ def load_vocab(path: Path, vocabtype: Optional[str]) -> SentencePieceVocab:
                 f"Could not find tokenizer.model in {path} or its parent; "
                 "if it's in another directory, pass the directory as --vocab-dir")
     added_tokens_path = path.parent / "added_tokens.json"
+    special_tokens_path = path.parent / "special_tokens_map.json"
+    tokenizer_config_path = path.parent / "tokenizer_config.json"
    print(f"Loading vocab file {path}")
-    return SentencePieceVocab(path, added_tokens_path if added_tokens_path.exists() else None,
+    return SentencePieceVocab(path, added_tokens_path if added_tokens_path.exists() else None, special_tokens_path if special_tokens_path.exists() else tokenizer_config_path if tokenizer_config_path.exists() else None,
                               vocabtype)
 
 
@@ -1313,6 +1345,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
         vocab_dir = args.vocab_dir if args.vocab_dir else model_plus.paths[0].parent
         vocab = load_vocab(vocab_dir, args.vocabtype)
         params = Params.load(model_plus)
+        params.n_vocab_sp = vocab.vocab_special_size
         model = model_plus.model
         model = do_necessary_conversions(model, params)
         output_type = pick_output_type(model, args.outtype)
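
For orientation: the convert.py changes above bump the ggjt file version to 4, add an n_vocab_sp field to the header, and make write_vocab() append the special-token ids after the regular (length, bytes, score) vocab entries. The sketch below is a hypothetical, minimal C++ reader for just that vocab section. It is not the actual llama.cpp loader; the names read_ggjt_v4_vocab and vocab_entry are invented for illustration, and it assumes the header fields after n_head (n_layer, n_rot, ftype) keep their previous order.

    // Hypothetical sketch only: reads the vocab section of a "ggjt" v4 file as written
    // by write_file_header()/write_vocab() above. Field order past n_head is assumed.
    #include <cstdint>
    #include <cstdio>
    #include <string>
    #include <vector>

    struct vocab_entry {
        std::string text;
        float       score;
    };

    static bool read_ggjt_v4_vocab(std::FILE * f,
                                   std::vector<vocab_entry> & tokens,
                                   std::vector<int32_t> & special_ids) {
        uint32_t magic = 0;
        int32_t  header[9]; // version, n_vocab, n_vocab_sp, n_embd, n_mult, n_head, n_layer, n_rot, ftype
        if (std::fread(&magic, sizeof(magic), 1, f) != 1) return false;
        if (magic != 0x67676a74u) return false;             // b"ggjt"[::-1], little-endian
        if (std::fread(header, sizeof(int32_t), 9, f) != 9) return false;
        if (header[0] != 4) return false;                   // file version written above
        const int32_t n_vocab    = header[1];
        const int32_t n_vocab_sp = header[2];

        // n_vocab records of (int32 length, raw bytes, float score) ...
        for (int32_t i = 0; i < n_vocab; i++) {
            int32_t len = 0;
            if (std::fread(&len, sizeof(len), 1, f) != 1 || len < 0) return false;
            std::string text(len, '\0');
            if (len > 0 && std::fread(&text[0], 1, len, f) != (size_t) len) return false;
            float score = 0.0f;
            if (std::fread(&score, sizeof(score), 1, f) != 1) return false;
            tokens.push_back({text, score});
        }

        // ... followed by the n_vocab_sp special-token ids appended by write_vocab().
        special_ids.resize(n_vocab_sp);
        if (n_vocab_sp > 0 &&
            std::fread(special_ids.data(), sizeof(int32_t), n_vocab_sp, f) != (size_t) n_vocab_sp) {
            return false;
        }
        return true;
    }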

llama-util.h

Lines changed: 164 additions & 0 deletions
@@ -14,6 +14,8 @@
 
 #include <string>
 #include <vector>
+#include <map>
+#include <unordered_map>
 #include <stdexcept>
 
 #ifdef __has_include
@@ -541,4 +543,166 @@ struct llama_ctx_buffer {
 typedef llama_buffer llama_ctx_buffer;
 #endif
 
+struct llama_trie_node {
+    llama_trie_node(): is_terminator(false) {}
+
+    std::unordered_map<char, llama_trie_node*> children;
+    bool is_terminator;
+};
+
+// Trie in C++. Creates a Trie out of a list of words. The trie is used to split on multiple delimiters in one pass
+// Ported from: https://github.com/huggingface/transformers/blob/ee88ae59940fd4b2c8fc119373143d7a1175c651/src/transformers/tokenization_utils.py#L52
+struct llama_trie {
+public:
+    llama_trie(): root_(new llama_trie_node()) {}
+
+    void add(const std::string & word) {
+        if (word.empty()) {
+            return;
+        }
+
+        llama_trie_node *ref = root_;
+        for (char c : word) {
+            if (ref->children.find(c) == ref->children.end()) {
+                ref->children[c] = new llama_trie_node();
+            }
+            ref = ref->children[c];
+        }
+        ref->is_terminator = true;
+    }
+
+    // Will look for the words added to the trie within `text`. Output is the boundaries of the words found.
+    // Note that this trie will match the longest possible word first!
+    std::vector<int> split(const std::string & text) const {
+        std::map<int, llama_trie_node*> states;
+        std::vector<int> offsets{0};
+
+        int skip = 0;
+        for (int current = 0; current < text.size(); current++) {
+            char current_char = text[current];
+            if (skip > 0 && current < skip) {
+                // Prevents the lookahead for matching twice
+                // like extra_id_100 and id_100
+                continue;
+            }
+
+            // Whenever we found a match, we need to drop everything
+            // this is a greedy algorithm, it will match on the first found token
+            bool reset = false;
+
+            // In this case, we already have partial matches (But unfinished)
+            for (auto state = states.begin(); state != states.end(); ) {
+                int start = state->first;
+                llama_trie_node *trie_pointer = state->second;
+                if (trie_pointer->is_terminator) {
+                    // This is a final match, we need to reset and
+                    // store the results in `offsets`.
+
+                    // Lookahead to match longest first
+                    // Important in case of extra_id_1 vs extra_id_100
+                    // Here we are also actively looking for other earlier partial
+                    // matches
+                    // "[CLS]", "L", we need to match CLS even if L is special
+                    int end = 0;
+                    for (const auto & look : states) {
+                        int lookstart = look.first;
+                        llama_trie_node *looktrie_pointer = look.second;
+                        int lookahead_index = 0;
+                        if (lookstart > start) {
+                            // This partial match is later, we can stop looking
+                            break;
+                        }
+                        if (lookstart < start) {
+                            // This partial match is earlier, the trie pointer
+                            // was already updated, so index is + 1
+                            lookahead_index = current + 1;
+                            end = current + 1;
+                        } else {
+                            // Here lookstart == start and
+                            // looktrie_pointer == trie_pointer
+                            // It wasn't updated yet so indices are current ones
+                            lookahead_index = current;
+                            end = current;
+                        }
+                        char next_char = lookahead_index < text.size() ? text[lookahead_index] : '\0';
+                        if (looktrie_pointer->is_terminator) {
+                            start = lookstart;
+                            end = lookahead_index;
+                            skip = lookahead_index;
+                        }
+
+                        auto looktrie_pointer_it = looktrie_pointer->children.find(next_char);
+                        while (looktrie_pointer_it != looktrie_pointer->children.end()) {
+                            looktrie_pointer = looktrie_pointer_it->second;
+                            lookahead_index++;
+                            if (looktrie_pointer->is_terminator) {
+                                start = lookstart;
+                                end = lookahead_index;
+                                skip = lookahead_index;
+                            }
+
+                            if (lookahead_index == text.size()) {
+                                // End of string
+                                break;
+                            }
+                            next_char = text[lookahead_index];
+                            looktrie_pointer_it = looktrie_pointer->children.find(next_char);
+                        }
+                    }
+
+                    offsets.push_back(start);
+                    offsets.push_back(end);
+                    reset = true;
+                    break;
+                }
+
+                auto trie_pointer_it = trie_pointer->children.find(current_char);
+                if (trie_pointer_it != trie_pointer->children.end()) {
+                    // The current character being looked at has a match within the trie
+                    // update the pointer (it will be stored back into states later).
+                    trie_pointer = trie_pointer_it->second;
+                    states[start] = trie_pointer;
+                    ++state;
+                } else {
+                    // The new character has no match in the trie, we need
+                    // to stop keeping track of this partial match.
+                    state = states.erase(state);
+                }
+            }
+
+            if (reset) {
+                // Clear the full start (we found a real match)
+                states.clear();
+            }
+
+            // If this character is a starting character within the trie
+            // start keeping track of this partial match.
+            auto children_it = root_->children.find(current_char);
+            if (current >= skip && children_it != root_->children.end()) {
+                states[current] = children_it->second;
+            }
+        }
+
+        // We have a cut at the end with states.
+        for (const auto & state : states) {
+            int start = state.first;
+            llama_trie_node *trie_pointer = state.second;
+            if (trie_pointer->is_terminator) {
+                // This is a final match, we need to reset and
+                // store the results in `offsets`.
+                int end = text.size();
+                offsets.push_back(start);
+                offsets.push_back(end);
+                break;
+            }
+        }
+
+        offsets.push_back(text.size());
+        return offsets;
+    }
+
+private:
+    llama_trie_node *root_;
+};
+
 #endif
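
For orientation, here is a small hypothetical usage sketch of the new trie (not part of the commit): add a few illustrative special-token strings, split a prompt, and walk the returned offsets. split() yields {0, match_start, match_end, ..., text.size()}, so consecutive offsets delimit alternating plain-text and special-token segments, with duplicate offsets marking empty segments.

    // Hypothetical usage sketch of llama_trie::add()/split() from the patch above.
    // The token strings are illustrative and chosen only for this demo.
    #include <cstdio>
    #include <string>
    #include <vector>

    #include "llama-util.h" // the header patched above

    int main() {
        llama_trie trie;
        trie.add("<s>");
        trie.add("</s>");
        trie.add("<unk>");

        const std::string text = "<s>Hello world</s>";
        std::vector<int> offsets = trie.split(text); // for this input: {0, 0, 3, 14, 18, 18}

        for (size_t i = 0; i + 1 < offsets.size(); i++) {
            const int start = offsets[i];
            const int end   = offsets[i + 1];
            if (start == end) {
                continue; // duplicate boundary, no text in between
            }
            std::printf("segment: '%s'\n", text.substr(start, end - start).c_str());
        }
        // Prints "<s>", "Hello world", "</s>"; special tokens come out as their own segments.
        return 0;
    }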
