Commit ea2c33f

Better loading of special tokens from jsons
1 parent 6c55fe1 commit ea2c33f

1 file changed: +33, -5 lines

convert.py

Lines changed: 33 additions & 5 deletions
@@ -156,8 +156,9 @@ def guessed(model: 'LazyModel', vocab: 'Vocab', file_type: GGMLFileType) -> 'Params':
 
 
 class SentencePieceVocab:
-    def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path], fname_special_tokens: Optional[Path]) -> None:
+    def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path], fname_special_tokens: Optional[Path], fname_tokenizer_config: Optional[Path]) -> None:
         self.sentencepiece_tokenizer = SentencePieceProcessor(str(fname_tokenizer))
+
         added_tokens: Dict[str, int]
         if fname_added_tokens is not None:
             added_tokens = json.load(open(fname_added_tokens))
@@ -174,13 +175,40 @@ def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path], fname_special_tokens: Optional[Path]) -> None:
         self.vocab_size: int = self.vocab_size_base + len(self.added_tokens_list)
         self.fname_tokenizer = fname_tokenizer
         self.fname_added_tokens = fname_added_tokens
-        special_tokens: Dict[str, Dict[str, Any]]
+        self.special_tokens_map: Dict[int, str] = {}
+
+        TOKEN_NAME_TO_ID: Dict[str, int] = {
+            "unk_token": self.sentencepiece_tokenizer.unk_id(),
+            "bos_token": self.sentencepiece_tokenizer.bos_id(),
+            "eos_token": self.sentencepiece_tokenizer.eos_id(),
+            "pad_token": self.sentencepiece_tokenizer.pad_id()
+        }
+
+        tokenizer_config: Dict[str, Any]
+        if fname_tokenizer_config is not None:
+            tokenizer_config = json.load(open(fname_tokenizer_config))
+        else:
+            tokenizer_config = {}
+        for key, value in tokenizer_config.items():
+            assert isinstance(value, dict) or isinstance(value, str)
+            if key not in TOKEN_NAME_TO_ID or TOKEN_NAME_TO_ID[key] == -1:
+                continue
+            self.special_tokens_map[TOKEN_NAME_TO_ID[key]] = value["content"] if isinstance(value, dict) else value
+
+        special_tokens: Dict[str, Any]
         if fname_special_tokens is not None:
             special_tokens = json.load(open(fname_special_tokens))
         else:
             special_tokens = {}
-        token_name_to_id = {"unk_token": self.sentencepiece_tokenizer.unk_id(), "bos_token": self.sentencepiece_tokenizer.bos_id(), "eos_token": self.sentencepiece_tokenizer.eos_id(), "pad_token": self.sentencepiece_tokenizer.pad_id()}
-        self.special_tokens_map = {token_name_to_id[token_name]: info["content"] if isinstance(info, dict) else info for token_name, info in special_tokens.items() if token_name in token_name_to_id and token_name_to_id[token_name] != -1}
+        for key, value in special_tokens.items():
+            assert isinstance(value, dict) or isinstance(value, str)
+            if key not in TOKEN_NAME_TO_ID:
+                continue
+            token_id = TOKEN_NAME_TO_ID[key]
+            if token_id == -1 or token_id in self.special_tokens_map:
+                continue
+            self.special_tokens_map[token_id] = value["content"] if isinstance(value, dict) else value
+
         self.vocab_special_size: int = len(self.added_tokens_list) + len(self.special_tokens_map)
 
     def sentencepiece_tokens(self) -> Iterable[Tuple[bytes, float]]:
@@ -1133,7 +1161,7 @@ def load_vocab(path: Path) -> SentencePieceVocab:
     special_tokens_path = path.parent / "special_tokens_map.json"
     tokenizer_config_path = path.parent / "tokenizer_config.json"
     print(f"Loading vocab file {path}")
-    return SentencePieceVocab(path, added_tokens_path if added_tokens_path.exists() else None, special_tokens_path if special_tokens_path.exists() else tokenizer_config_path if tokenizer_config_path.exists() else None)
+    return SentencePieceVocab(path, added_tokens_path if added_tokens_path.exists() else None, special_tokens_path if special_tokens_path.exists() else None, tokenizer_config_path if tokenizer_config_path.exists() else None)
 
 
 def default_outfile(model_paths: List[Path], params: Params) -> Path:
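
To illustrate what the new loading order does, here is a minimal, self-contained sketch of the merge logic introduced above. The token IDs and JSON contents below are made up for the example; in convert.py the IDs come from the SentencePieceProcessor (unk_id(), bos_id(), eos_id(), pad_id()) and the two dicts are read from tokenizer_config.json and special_tokens_map.json. Entries from tokenizer_config.json are applied first, special_tokens_map.json then only fills IDs that are still unset, and IDs of -1 are skipped.

from typing import Any, Dict

# Hypothetical stand-ins for the SentencePiece special-token IDs
# (pad_id() is -1 when the model defines no padding token).
TOKEN_NAME_TO_ID: Dict[str, int] = {"unk_token": 0, "bos_token": 1, "eos_token": 2, "pad_token": -1}

# Example file contents: tokenizer_config.json entries may be dicts with a
# "content" field, special_tokens_map.json entries may be plain strings.
tokenizer_config: Dict[str, Any] = {"bos_token": {"content": "<s>"}, "eos_token": "</s>"}
special_tokens: Dict[str, Any] = {"unk_token": "<unk>", "eos_token": "<other-eos>"}

special_tokens_map: Dict[int, str] = {}

# Pass 1: tokenizer_config.json wins; unknown names and -1 IDs are ignored.
for key, value in tokenizer_config.items():
    if key not in TOKEN_NAME_TO_ID or TOKEN_NAME_TO_ID[key] == -1:
        continue
    special_tokens_map[TOKEN_NAME_TO_ID[key]] = value["content"] if isinstance(value, dict) else value

# Pass 2: special_tokens_map.json only fills IDs that are still missing.
for key, value in special_tokens.items():
    if key not in TOKEN_NAME_TO_ID:
        continue
    token_id = TOKEN_NAME_TO_ID[key]
    if token_id == -1 or token_id in special_tokens_map:
        continue
    special_tokens_map[token_id] = value["content"] if isinstance(value, dict) else value

print(special_tokens_map)  # {1: '<s>', 2: '</s>', 0: '<unk>'} -- eos kept from tokenizer_config, pad skipped

Before this change, load_vocab() passed tokenizer_config.json in place of special_tokens_map.json when the latter was missing, so both file formats shared a single constructor argument; now each file has its own parameter and both can contribute special tokens.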
