Commit bc78bf4

convert-hf : faster model parts loading
Instead of pre-loading them all into a dict, iterate on the tensors in the model parts progressively, as needed, in Model.write_tensors.

Conversion for some architectures relies on checking for the presence of specific tensor names, so for multi-part models, the weight map is read from the relevant json file to quickly get these names up-front.
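For context: a multi-part Hugging Face checkpoint ships an index file (model.safetensors.index.json, or pytorch_model.bin.index.json for .bin parts) whose "weight_map" maps each tensor name to the part file that stores it, which is why the full set of tensor names can be collected without opening any part. A minimal sketch of that lookup, assuming a hypothetical checkpoint directory models/example (the path and the printed summary are illustrative, not part of this commit):

import json
from pathlib import Path

dir_model = Path("models/example")  # hypothetical checkpoint directory

# The index file typically looks like:
#   {"metadata": {"total_size": ...},
#    "weight_map": {"lm_head.weight": "model-00002-of-00002.safetensors", ...}}
with open(dir_model / "model.safetensors.index.json", encoding="utf-8") as f:
    index = json.load(f)

weight_map = index["weight_map"]
tensor_names = set(weight_map.keys())          # known up-front, no tensor data loaded
part_files = sorted(set(weight_map.values()))  # which parts the tensors live in
print(f"{len(tensor_names)} tensors across {len(part_files)} part(s)")

Having these names up-front keeps presence checks such as the lm_head.weight / output.weight tests in the diff below cheap, and as the new comment in get_tensors() notes, only presence matters, not which part a given tensor ends up in.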
1 parent 98db434 commit bc78bf4

File tree

1 file changed: +36 -10 lines changed

convert-hf-to-gguf.py

Lines changed: 36 additions & 10 deletions
@@ -52,13 +52,14 @@ class Model:
     is_big_endian: bool
     endianess: gguf.GGUFEndian
     use_temp_file: bool
+    lazy: bool
     part_names: list[str]
     is_safetensors: bool
     hparams: dict[str, Any]
     gguf_writer: gguf.GGUFWriter
     block_count: int
     tensor_map: gguf.TensorNameMap
-    tensors: dict[str, Tensor]
+    tensor_names: set[str] | None

     # subclasses should define this!
     model_arch: gguf.MODEL_ARCH
@@ -72,6 +73,7 @@ def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian:
         self.is_big_endian = is_big_endian
         self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
         self.use_temp_file = use_temp_file
+        self.lazy = not eager
         self.part_names = Model.get_model_part_names(self.dir_model, ".safetensors")
         self.is_safetensors = len(self.part_names) > 0
         if not self.is_safetensors:
@@ -80,10 +82,7 @@ def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian:
         self.gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file)
         self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer"])
         self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
-        self.tensors = dict(self.get_tensors())
-        if not eager:
-            for k, v in self.tensors.items():
-                self.tensors[k] = LazyTorchTensor.from_eager(v)
+        self.tensor_names = None

     @classmethod
     def __init_subclass__(cls):
@@ -104,6 +103,22 @@ def set_vocab(self):
         self._set_vocab_gpt2()

     def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
+        tensor_names_from_parts: set[str] = set()
+
+        if len(self.part_names) > 1:
+            self.tensor_names = set()
+            index_name = "model.safetensors" if self.is_safetensors else "pytorch_model.bin"
+            index_name += ".index.json"
+            logger.info(f"gguf: loading model weight map from '{index_name}'")
+            with open(self.dir_model / index_name, "r", encoding="utf-8") as f:
+                index: dict[str, Any] = json.load(f)
+                weight_map = index.get("weight_map")
+                if weight_map is None or not isinstance(weight_map, dict):
+                    raise ValueError(f"Can't load 'weight_map' from {index_name!r}")
+                self.tensor_names.update(weight_map.keys())
+        else:
+            self.tensor_names = tensor_names_from_parts
+
         for part_name in self.part_names:
             logger.info(f"gguf: loading model part '{part_name}'")
             ctx: ContextManager[Any]
@@ -114,10 +129,18 @@ def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
                 ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", mmap=True, weights_only=True))

             with ctx as model_part:
+                tensor_names_from_parts.update(model_part.keys())
+
                 for name in model_part.keys():
                     data = model_part.get_tensor(name) if self.is_safetensors else model_part[name]
+                    if self.lazy:
+                        data = LazyTorchTensor.from_eager(data)
                     yield name, data

+        # only verify tensor name presence; it doesn't matter if they are not in the right files
+        if len(sym_diff := tensor_names_from_parts.symmetric_difference(self.tensor_names)) > 0:
+            raise ValueError(f"Mismatch between weight map and model parts for tensor names: {sym_diff}")
+
     def format_tensor_name(self, key: gguf.MODEL_TENSOR, bid: int | None = None, suffix: str = ".weight") -> str:
         name: str = gguf.TENSOR_NAMES[key]
         if key not in gguf.MODEL_TENSORS[self.model_arch]:
@@ -194,7 +217,7 @@ def extra_f16_tensors(self, name: str, new_name: str, bid: int | None, n_dims: i
     def write_tensors(self):
         max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,")

-        for name, data_torch in self.tensors.items():
+        for name, data_torch in self.get_tensors():
             # we don't need these
             if name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")):
                 continue
@@ -248,8 +271,6 @@ def write_tensors(self):

     def write(self):
         self.write_tensors()
-        self.tensors.clear()  # save memory by not keeping references to the tensors
-
         self.gguf_writer.write_header_to_file()
         self.gguf_writer.write_kv_data_to_file()
         self.gguf_writer.write_tensors_to_file(progress=True)
@@ -621,8 +642,10 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         tensors.append((self.map_tensor_name(name), data_torch))

         if name == "word_embeddings.weight":
+            assert self.tensor_names is not None
+
             # TODO: tie them at runtime, don't duplicate in the model file
-            if "lm_head.weight" not in self.tensors and "output.weight" not in self.tensors:
+            if all(s not in self.tensor_names for s in ("lm_head.weight", "output.weight")):
                 tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT), data_torch))

         return tensors
@@ -1759,7 +1782,10 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         tensors: list[tuple[str, Tensor]] = [(new_name, data_torch)]

         if new_name == self.format_tensor_name(gguf.MODEL_TENSOR.TOKEN_EMBD):
-            if "lm_head.weight" not in self.tensors and "output.weight" not in self.tensors:
+            assert self.tensor_names is not None
+
+            if all(s not in self.tensor_names for s in ("lm_head.weight", "output.weight")):
+                # copy tok_embd.weight to output.weight
                 tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT), data_torch))

         return tensors
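The other half of the change is structural: get_tensors() is now a generator that write_tensors() consumes directly, so tensors are read part by part instead of being held in one dict for the whole conversion. A standalone sketch of that pattern under hypothetical names (plain Python lists stand in for tensors; this is not the converter's actual code):

from typing import Iterator

# Hypothetical stand-ins for model part files; the real converter reads
# .safetensors / .bin parts and wraps each tensor in LazyTorchTensor when lazy.
PARTS: list[dict[str, list[float]]] = [
    {"token_embd.weight": [0.1, 0.2]},
    {"output_norm.weight": [1.0]},
]

def iter_tensors(parts: list[dict[str, list[float]]]) -> Iterator[tuple[str, list[float]]]:
    # Yield tensors progressively, one part at a time, instead of
    # building a single dict of every tensor up-front.
    for part in parts:
        yield from part.items()

seen_names: set[str] = set()
for name, data in iter_tensors(PARTS):
    seen_names.add(name)  # names are still collected for presence checks
    print(f"writing {name} ({len(data)} values)")

print("has output.weight:", "output.weight" in seen_names)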
