Skip to content

Commit 0a1ef11

Browse files
committed
Write tensors in layer order.
1 parent 60b29ea commit 0a1ef11

File tree

1 file changed

+21
-10
lines changed

1 file changed

+21
-10
lines changed

convert_grok.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -197,12 +197,20 @@ def get_dtype_and_ggml_type(name, tensor, ggml_type):
197197

198198

199199
def dump_state_dict(f, ggml_type, input_dir, config):
200-
weight_names = get_weight_names(config.num_hidden_layers)
201200
weights = {}
202201

203-
# First operate on meta tensors to find shapes and dtypes for GGUF header.
204-
for idx, name in enumerate(weight_names):
205-
weight, scales = get_weights(f"{input_dir}/tensor{idx:05}_000")
202+
# Load weights in file order (mmap'ed).
203+
for idx, name in enumerate(get_weight_names(config.num_hidden_layers)):
204+
weights[name] = get_weights(f"{input_dir}/tensor{idx:05}_000")
205+
206+
logging.debug("Loaded %i files", len(weights))
207+
208+
# But write in layer order.
209+
weight_names = get_weight_names(config.num_hidden_layers, lexicographic=False)
210+
211+
# Operate on meta tensors to find shapes and dtypes for GGUF header.
212+
for name in weight_names:
213+
weight, scales = weights[name]
206214
meta_tensor = convert_weight(name, weight, scales, config, device="meta")
207215
dtype, tensor_ggml_type = get_dtype_and_ggml_type(name, meta_tensor, ggml_type)
208216
quantized_meta_tensor = maybe_quantize_tensor(meta_tensor, tensor_ggml_type)
@@ -213,8 +221,6 @@ def dump_state_dict(f, ggml_type, input_dir, config):
213221
quantized_meta_tensor.nbytes,
214222
tensor_ggml_type,
215223
)
216-
weights[name] = weight, scales
217-
logging.debug("Loaded %i files", len(weight_names))
218224

219225
f.write_header_to_file()
220226
f.write_kv_data_to_file()
@@ -244,7 +250,7 @@ def dump_state_dict(f, ggml_type, input_dir, config):
244250
except NameError:
245251
pass
246252

247-
if len(tensor_info) != len(weight_names):
253+
if weights:
248254
logging.warning("Not all tensors are converted")
249255

250256

@@ -293,8 +299,10 @@ def extract_vocabulary_from_model(vocab):
293299
return tokens, scores, toktypes
294300

295301

296-
def get_weight_names(num_hidden_layers=64):
297-
"""Return Grok-1 weight names, in the order in which they are in the tensor#####_000 files."""
302+
def get_weight_names(num_hidden_layers=64, lexicographic=True):
303+
"""Return Grok-1 weight names.
304+
305+
If `lexicographic` is set, the order is as in the tensor#####_000 files."""
298306

299307
weight_names = [
300308
gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.TOKEN_EMBD],
@@ -317,7 +325,10 @@ def get_weight_names(num_hidden_layers=64):
317325
)
318326

319327
layers = [str(bid) for bid in range(64)]
320-
layers.sort() # Lexicographic sort: 0 < 1 < 10 < 11 ... < 2 < 20 < ...
328+
329+
if lexicographic:
330+
# Lexicographic sort: 0 < 1 < 10 < 11 ... < 2 < 20 < ...
331+
layers.sort()
321332

322333
for bid in layers[:num_hidden_layers]:
323334
for key in layer:

0 commit comments

Comments (0)