Commit 7e4a4eb

refactor: Enhance readability, functionality, and code quality
- Improved code formatting and readability for better maintainability.
- Refactored LazyUnpickler's CLASSES dictionary for clarity.
- Added print statements and warnings in check_vocab_size for user feedback.
- Removed find_vocab_file_path, as it's superseded by VocabFactory.
- Preparatory changes for upcoming classes: OutputFile and VocabFactory.
- Overall focus on code quality, error handling, and consistency.

These changes reflect a continuous effort to refine the codebase, ensuring it meets best practices and prepares for future enhancements, such as the VocabFactory.
1 parent db4b8ac commit 7e4a4eb

1 file changed (+26 -27)

convert.py

Lines changed: 26 additions & 27 deletions
@@ -868,13 +868,17 @@ def rebuild_from_type_v2(func, new_type, args, state):
     CLASSES: dict[tuple[str, str], Any] = {
         # getattr used here as a workaround for mypy not being smart enough to determine
         # the staticmethods have a __func__ attribute.
-        ('torch._tensor', '_rebuild_from_type_v2'): getattr(rebuild_from_type_v2, '__func__'),
-        ('torch._utils', '_rebuild_tensor_v2'): getattr(lazy_rebuild_tensor_v2, '__func__'),
-        ('torch', 'BFloat16Storage'): LazyStorageKind(DT_BF16),
-        ('torch', 'HalfStorage'): LazyStorageKind(DT_F16),
-        ('torch', 'FloatStorage'): LazyStorageKind(DT_F32),
-        ('torch', 'IntStorage'): LazyStorageKind(DT_I32),
-        ('torch', 'Tensor'): LazyTensor,
+        ("torch._tensor", "_rebuild_from_type_v2"): getattr(
+            rebuild_from_type_v2, "__func__"
+        ),
+        ("torch._utils", "_rebuild_tensor_v2"): getattr(
+            lazy_rebuild_tensor_v2, "__func__"
+        ),
+        ("torch", "BFloat16Storage"): LazyStorageKind(DT_BF16),
+        ("torch", "HalfStorage"): LazyStorageKind(DT_F16),
+        ("torch", "FloatStorage"): LazyStorageKind(DT_F32),
+        ("torch", "IntStorage"): LazyStorageKind(DT_I32),
+        ("torch", "Tensor"): LazyTensor,
     }

     def find_class(self, module: str, name: str) -> Any:
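For context on what the reflowed table does: LazyUnpickler resolves each pickled (module, name) global through CLASSES from its find_class override (the hunk's trailing context line); the getattr(..., '__func__') entries are the mypy staticmethod workaround the in-code comment describes, and this commit only changes the formatting and quote style. Below is a minimal, self-contained sketch of that find_class pattern with a placeholder table, not convert.py's actual implementation:

import io
import pickle
from collections import OrderedDict
from typing import Any


class SelectiveUnpickler(pickle.Unpickler):
    # Resolve pickled globals through a fixed (module, name) table instead of
    # importing arbitrary objects -- the same shape of mapping as
    # LazyUnpickler.CLASSES, which substitutes lazy stand-in classes.
    CLASSES: dict[tuple[str, str], Any] = {
        ("collections", "OrderedDict"): OrderedDict,  # placeholder entry for this sketch
    }

    def find_class(self, module: str, name: str) -> Any:
        try:
            return self.CLASSES[(module, name)]
        except KeyError:
            raise pickle.UnpicklingError(f"unexpected global: {module}.{name}") from None


data = pickle.dumps(OrderedDict(a=1))
restored = SelectiveUnpickler(io.BytesIO(data)).load()  # allowed: the global is in the table
assert restored == OrderedDict(a=1)  # anything not listed raises pickle.UnpicklingError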
@@ -985,24 +989,32 @@ def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], conc
 def check_vocab_size(params: Params, vocab: Vocab, pad_vocab: bool = False) -> None:
     if params.n_vocab != vocab.vocab_size:
         if params.n_vocab == vocab.vocab_size:
-            print("Ignoring added_tokens.json since model matches vocab size without it.")
-            vocab.added_tokens_dict = OrderedDict()
-            vocab.vocab_size = vocab.vocab_size
+            print(
+                "Ignoring added_tokens.json since model matches vocab size without it."
+            )
             return
-
         if pad_vocab and params.n_vocab > vocab.vocab_size:
             pad_count = params.n_vocab - vocab.vocab_size
-            print(f'Padding vocab with {pad_count} token(s) - <dummy00001> through <dummy{pad_count:05}>')
+            print(
+                f"Padding vocab with {pad_count} token(s) - <dummy00001> through <dummy{pad_count:05}>"
+            )
             for i in range(1, (params.n_vocab - vocab.vocab_size) + 1):
-                vocab.added_tokens_dict[f'<dummy{i:05}>'] = -1
+                vocab.added_tokens_dict[f"<dummy{i:05}>"] = -1
             vocab.vocab_size = params.n_vocab
             return
         msg = f"Vocab size mismatch (model has {params.n_vocab}, but {vocab.fname_tokenizer}"
         msg += f" has {vocab.vocab_size})."
         if vocab.vocab_size < params.n_vocab < vocab.vocab_size + 20:
             msg += f" Most likely you are missing added_tokens.json (should be in {vocab.fname_tokenizer.parent})."
         if vocab.vocab_size < params.n_vocab:
-            msg += " Possibly try using the --padvocab option."
+            msg += " Add the --pad-vocab option and try again."
+
+        # Check if params.n_vocab is -1 and issue a warning
+        if params.n_vocab == -1:
+            warnings.warn(
+                "WARNING: The model's vocab size is set to -1 in params.json. Please update it manually."
+            )
+
         raise Exception(msg)

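To make the padding path above concrete: when the tokenizer falls short of the model's declared n_vocab, the loop fills the gap with <dummy00001> through <dummyNNNNN> entries mapped to -1 and then bumps vocab.vocab_size. The sketch below mirrors that behaviour with a hypothetical SimpleVocab stand-in and pad_vocab_to helper, not convert.py's Params/Vocab classes; the new warnings.warn call in the hunk is assumed to rely on the warnings module being imported elsewhere in convert.py (not shown here).

import warnings
from dataclasses import dataclass, field


@dataclass
class SimpleVocab:
    # Hypothetical stand-in for convert.py's Vocab; only the fields this sketch needs.
    vocab_size: int
    added_tokens_dict: dict[str, int] = field(default_factory=dict)


def pad_vocab_to(n_vocab: int, vocab: SimpleVocab) -> None:
    # Mirrors the diff's padding branch: add <dummy00001>..<dummyNNNNN> (value -1)
    # until the tokenizer-side vocab matches the model's declared size.
    if n_vocab == -1:
        warnings.warn("n_vocab is -1 in params.json; fix it before converting.")
        return
    pad_count = n_vocab - vocab.vocab_size
    for i in range(1, pad_count + 1):
        vocab.added_tokens_dict[f"<dummy{i:05}>"] = -1
    vocab.vocab_size = n_vocab


vocab = SimpleVocab(vocab_size=31998)
pad_vocab_to(32000, vocab)
assert vocab.vocab_size == 32000 and "<dummy00002>" in vocab.added_tokens_dict

Note also that the mismatch message now refers to --pad-vocab rather than the old --padvocab spelling, presumably matching a flag rename made elsewhere in this refactor.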
@@ -1289,19 +1301,6 @@ def load_some_model(path: Path) -> ModelPlus:
     return model_plus


-def find_vocab_file_path(path: Path, vocab_file: str) -> Optional[Path]:
-    path2 = path / vocab_file
-    # Use `.parent` instead of /.. to handle the symlink case better.
-    path3 = path.parent / vocab_file
-
-    if path2.exists():
-        return path2
-    if path3.exists():
-        return path3
-
-    return None
-
-
 def default_outfile(model_paths: list[Path], file_type: GGMLFileType) -> Path:
     namestr = {
         GGMLFileType.AllF32: "f32",

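One note on the deleted helper: its comment about using .parent instead of /.. concerns symlinked model directories. A small pathlib sketch of the difference, assuming a hypothetical layout where ./models/llama is a symlink to /data/checkpoints/llama-7b:

from pathlib import Path

p = Path("models/llama")  # assumed to be a symlink to /data/checkpoints/llama-7b

# Lexical parent: stays next to the symlink itself, which is what
# find_vocab_file_path relied on.
lexical = p.parent / "tokenizer.model"  # models/tokenizer.model

# A '..' component, once it hits the filesystem (as .exists()/open() would),
# follows the symlink's target first, landing next to the real checkpoint dir.
via_dotdot = (p / "..").resolve() / "tokenizer.model"  # /data/checkpoints/tokenizer.model given that layout

print(lexical, via_dotdot)

The lookup itself is not lost; per the commit message it is superseded by the upcoming VocabFactory.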