Commit 8796025

Make gguf SpecialVocab vocab size-aware
Update conversion scripts accordingly
1 parent 3a007e2 commit 8796025

10 files changed: 43 additions (+) and 22 deletions (-)
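
In short: gguf.SpecialVocab now accepts an optional n_vocab argument giving the real vocabulary size, and any special token ID it finds (in tokenizer.json or config.json) that falls outside that range is skipped with a warning instead of being written to the GGUF file. The HF conversion scripts pass len(tokens); convert.py and the GGML converter pass vocab.vocab_size. A minimal sketch of the new call pattern (dir_model, tokens and gguf_writer stand for the objects each script already has at that point; illustrative only):

    # Mirrors the one-line change in the *-hf-to-gguf.py scripts below.
    # `tokens` is the token list already written via gguf_writer, so len(tokens)
    # is the actual vocabulary size the model will ship with.
    special_vocab = gguf.SpecialVocab(dir_model, load_merges = True, n_vocab = len(tokens))
    special_vocab.add_to_gguf(gguf_writer)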

convert-baichuan-hf-to-gguf.py

Lines changed: 1 addition & 1 deletion
@@ -224,7 +224,7 @@ def parse_args() -> argparse.Namespace:
 gguf_writer.add_token_scores(scores)
 gguf_writer.add_token_types(toktypes)

-special_vocab = gguf.SpecialVocab(dir_model)
+special_vocab = gguf.SpecialVocab(dir_model, n_vocab = len(tokens))
 special_vocab.add_to_gguf(gguf_writer)

 # TENSORS

convert-bloom-hf-to-gguf.py

Lines changed: 1 addition & 1 deletion
@@ -129,7 +129,7 @@ def parse_args() -> argparse.Namespace:
 gguf_writer.add_token_scores(scores)
 gguf_writer.add_token_types(toktypes)

-special_vocab = gguf.SpecialVocab(dir_model, load_merges=True)
+special_vocab = gguf.SpecialVocab(dir_model, load_merges=True, n_vocab = len(tokens))
 special_vocab.add_to_gguf(gguf_writer)

 # TENSORS

convert-falcon-hf-to-gguf.py

Lines changed: 1 addition & 1 deletion
@@ -145,7 +145,7 @@ def parse_args() -> argparse.Namespace:
 gguf_writer.add_token_scores(scores)
 gguf_writer.add_token_types(toktypes)

-special_vocab = gguf.SpecialVocab(dir_model, load_merges = True)
+special_vocab = gguf.SpecialVocab(dir_model, load_merges = True, n_vocab = len(tokens))
 special_vocab.add_to_gguf(gguf_writer)

 # TENSORS

convert-gptneox-hf-to-gguf.py

Lines changed: 1 addition & 1 deletion
@@ -134,7 +134,7 @@ def parse_args() -> argparse.Namespace:
 gguf_writer.add_token_scores(scores)
 gguf_writer.add_token_types(toktypes)

-special_vocab = gguf.SpecialVocab(dir_model, load_merges = True)
+special_vocab = gguf.SpecialVocab(dir_model, load_merges = True, n_vocab = len(tokens))
 special_vocab.add_to_gguf(gguf_writer)

 # TENSORS

convert-llama-ggml-to-gguf.py

Lines changed: 3 additions & 1 deletion
@@ -388,7 +388,9 @@ def handle_metadata(cfg, hp):
 cfg.vocab_dir if cfg.vocab_dir is not None else cfg.model_metadata_dir,
 cfg.vocabtype )
 # FIXME: Respect cfg.vocab_dir?
-svocab = gguf.SpecialVocab(cfg.model_metadata_dir)
+svocab = gguf.SpecialVocab(cfg.model_metadata_dir,
+load_merges = cfg.vocabtype == 'bpe',
+n_vocab = vocab.vocab_size)
 convert.check_vocab_size(params, vocab)
 return (params, vocab, svocab)


convert-mpt-hf-to-gguf.py

Lines changed: 1 addition & 1 deletion
@@ -139,7 +139,7 @@ def parse_args() -> argparse.Namespace:
 gguf_writer.add_token_scores(scores)
 gguf_writer.add_token_types(toktypes)

-special_vocab = gguf.SpecialVocab(dir_model, load_merges = True)
+special_vocab = gguf.SpecialVocab(dir_model, load_merges = True, n_vocab = len(tokens))
 special_vocab.add_to_gguf(gguf_writer)

 # TENSORS

convert-refact-hf-to-gguf.py

Lines changed: 1 addition & 1 deletion
@@ -150,7 +150,7 @@ def parse_args() -> argparse.Namespace:
 gguf_writer.add_token_scores(scores)
 gguf_writer.add_token_types(toktypes)

-special_vocab = gguf.SpecialVocab(dir_model, load_merges=True)
+special_vocab = gguf.SpecialVocab(dir_model, load_merges=True, n_vocab = len(tokens))
 special_vocab.add_to_gguf(gguf_writer)

 # TENSORS

convert-starcoder-hf-to-gguf.py

Lines changed: 1 addition & 1 deletion
@@ -122,7 +122,7 @@ def parse_args() -> argparse.Namespace:
 gguf_writer.add_token_scores(scores)
 gguf_writer.add_token_types(toktypes)

-special_vocab = gguf.SpecialVocab(dir_model, load_merges = True)
+special_vocab = gguf.SpecialVocab(dir_model, load_merges = True, n_vocab = len(tokens))
 special_vocab.add_to_gguf(gguf_writer)

 # TENSORS

convert.py

Lines changed: 8 additions & 3 deletions
@@ -1159,10 +1159,13 @@ def main(args_in: list[str] | None = None) -> None:

 vocab: Vocab
 if args.vocab_only:
-assert args.outfile, "need --outfile if using --vocab-only"
+if not args.outfile:
+raise ValueError("need --outfile if using --vocab-only")
 # FIXME: Try to respect vocab_dir somehow?
 vocab = load_vocab(args.vocab_dir or args.model, args.vocabtype)
-special_vocab = gguf.SpecialVocab(model_plus.paths[0].parent, load_merges = args.vocabtype == 'bpe')
+special_vocab = gguf.SpecialVocab(model_plus.paths[0].parent,
+load_merges = args.vocabtype == 'bpe',
+n_vocab = vocab.vocab_size)
 outfile = args.outfile
 OutputFile.write_vocab_only(outfile, params, vocab, special_vocab)
 print(f"Wrote {outfile}")

@@ -1174,7 +1177,9 @@ def main(args_in: list[str] | None = None) -> None:
 vocab_dir = args.vocab_dir if args.vocab_dir else model_plus.paths[0].parent
 vocab = load_vocab(vocab_dir, args.vocabtype)
 # FIXME: Try to respect vocab_dir somehow?
-special_vocab = gguf.SpecialVocab(model_plus.paths[0].parent, load_merges = args.vocabtype == 'bpe')
+special_vocab = gguf.SpecialVocab(model_plus.paths[0].parent,
+load_merges = args.vocabtype == 'bpe',
+n_vocab = vocab.vocab_size)

 model = model_plus.model
 model = convert_model_names(model, params)
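
A side change in convert.py: the --vocab-only path now raises a ValueError instead of relying on assert. A plain assert statement is compiled away when Python runs with the -O flag, so raising an exception keeps the check in place regardless of interpreter options. A minimal illustration (not the full script; args is whatever argparse returned):

    # With `python -O`, an `assert args.outfile, ...` is removed entirely.
    # An explicit exception cannot be optimized out:
    if not args.outfile:
        raise ValueError("need --outfile if using --vocab-only")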

gguf-py/gguf/gguf.py

Lines changed: 25 additions & 11 deletions
@@ -968,12 +968,15 @@ class SpecialVocab:
 merges: list[str] = []
 special_token_types: tuple[str, ...] = ('bos', 'eos', 'unk', 'sep', 'pad')
 special_token_ids: dict[str, int] = {}
+n_vocab: int | None = None

 def __init__(
 self, path: str | os.PathLike[str], load_merges: bool = False,
 special_token_types: tuple[str, ...] | None = None,
+n_vocab: int | None = None,
 ):
 self.special_token_ids = {}
+self.n_vocab = n_vocab
 self.load_merges = load_merges
 if special_token_types is not None:
 self.special_token_types = special_token_types

@@ -983,6 +986,16 @@ def _load(self, path: Path) -> None:
 if not self._try_load_from_tokenizer_json(path):
 self._try_load_from_config_json(path)

+def _set_special_token(self, typ: str, tid: Any):
+if not isinstance(tid, int) or tid < 0:
+return
+if self.n_vocab is None or tid < self.n_vocab:
+self.special_token_ids[typ] = tid
+return
+print(f'gguf: WARNING: Special token type {typ}, id {tid} out of range, must be under {self.n_vocab} - skipping',
+file = sys.stderr)
+
+
 def _try_load_from_tokenizer_json(self, path: Path) -> bool:
 tokenizer_file = path / 'tokenizer.json'
 if not tokenizer_file.is_file():

@@ -1010,10 +1023,11 @@ def _try_load_from_tokenizer_json(self, path: Path) -> bool:
 tc_content = entry_content
 else:
 continue
-for maybe_token_id in (atok.get('id') for atok in added_tokens if atok.get('content') == tc_content):
-if isinstance(maybe_token_id, int) and maybe_token_id >= 0:
-self.special_token_ids[typ] = maybe_token_id
-break
+# We only need the first match here.
+maybe_token_id = next((
+atok.get('id') for atok in added_tokens
+if atok.get('content') == tc_content), None)
+self._set_special_token(typ, maybe_token_id)
 return True

 def _try_load_from_config_json(self, path: Path) -> bool:

@@ -1023,21 +1037,21 @@ def _try_load_from_config_json(self, path: Path) -> bool:
 with open(config_file, encoding = 'utf-8') as f:
 config = json.load(f)
 for typ in self.special_token_types:
-maybe_token_id = config.get(f'{typ}_token_id')
-if isinstance(maybe_token_id, int) and maybe_token_id >= 0:
-self.special_token_ids[typ] = maybe_token_id
+self._set_special_token(typ, config.get(f'{typ}_token_id'))
 return True

-def add_to_gguf(self, gw: GGUFWriter) -> None:
+def add_to_gguf(self, gw: GGUFWriter, quiet: bool = False) -> None:
 if len(self.merges) > 0:
-print(f'gguf: Adding {len(self.merges)} merge(s).')
+if not quiet:
+print(f'gguf: Adding {len(self.merges)} merge(s).')
 gw.add_token_merges(self.merges)
 for typ, tokid in self.special_token_ids.items():
 handler: Callable[[int], None] | None = getattr(gw, f'add_{typ}_token_id', None)
 if handler is None:
-print(f'gguf: WARNING: No handler for special token type {typ} with id {tokid} - skipping')
+print(f'gguf: WARNING: No handler for special token type {typ} with id {tokid} - skipping', file = sys.stderr)
 continue
-print(f'gguf: Setting special token type {typ} to {tokid}')
+if not quiet:
+print(f'gguf: Setting special token type {typ} to {tokid}')
 handler(tokid)

 def __repr__(self) -> str:
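
To see the new bounds check end to end, here is a small, self-contained sketch. It assumes the gguf-py package from this repository is importable and that the constructor loads special tokens from the model directory as in the class above; the directory layout and token IDs are invented for the example. It writes a config.json whose eos_token_id lies outside the claimed vocabulary size, so SpecialVocab keeps bos but warns about and drops eos:

    import json, tempfile
    from pathlib import Path

    import gguf  # gguf-py from this repository

    with tempfile.TemporaryDirectory() as model_dir:
        # No tokenizer.json here, so SpecialVocab falls back to config.json.
        (Path(model_dir) / 'config.json').write_text(
            json.dumps({'bos_token_id': 1, 'eos_token_id': 99999}))
        sv = gguf.SpecialVocab(model_dir, n_vocab = 32000)
        # A "Special token type eos, id 99999 out of range" warning goes to stderr
        # and only the in-range token is kept:
        print(sv.special_token_ids)  # -> {'bos': 1}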
