Skip to content

Commit 8a72b3d

Browse files
committed
Address review comments by foldl.
1 parent 098c14e commit 8a72b3d

File tree

1 file changed

+9
-23
lines changed

1 file changed

+9
-23
lines changed

convert_grok.py

Lines changed: 9 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ def dump_state_dict(f, ggml_type, input_dir, config):
204204
# First operate on meta tensors to find shapes and dtypes for GGUF header.
205205
for idx, name in enumerate(weight_names):
206206
weight, scales = get_weights(f"{input_dir}/tensor{idx:05}_000")
207-
meta_tensor = convert_weight(name, weight, scales, config.experts, device="meta")
207+
meta_tensor = convert_weight(name, weight, scales, config, device="meta")
208208
dtype, tensor_ggml_type = get_dtype_and_ggml_type(meta_tensor, ggml_type)
209209
quantized_meta_tensor = maybe_quantize_tensor(meta_tensor, tensor_ggml_type)
210210
f.add_tensor_info(
@@ -226,8 +226,7 @@ def dump_state_dict(f, ggml_type, input_dir, config):
226226

227227
for name in weight_names:
228228
weight, scales = weights.pop(name)
229-
tensor = convert_weight(name, weight, scales, config.experts)
230-
tensor = maybe_permute_tensor(name, tensor, config)
229+
tensor = convert_weight(name, weight, scales, config)
231230
_, tensor_ggml_type = get_dtype_and_ggml_type(tensor, ggml_type)
232231
array = maybe_quantize_tensor(tensor, tensor_ggml_type).numpy()
233232

@@ -258,7 +257,7 @@ def from_numpy(array):
258257
return torch.from_numpy(array)
259258

260259

261-
def convert_weight(name, weight, scales, experts, dtype=torch.float32, device=None):
260+
def convert_weight(name, weight, scales, config, dtype=torch.float32, device=None):
262261
# copied from https://gist.github.com/chu-tianxiang/ec310e15d56949fd0f351cb5f65ee7a1
263262
weight = from_numpy(weight).to(device=device, dtype=dtype)
264263
if scales is not None:
@@ -271,32 +270,19 @@ def convert_weight(name, weight, scales, experts, dtype=torch.float32, device=No
271270
else:
272271
weight = weight * scale
273272

274-
# Transpose linear matrix
275-
if len(weight.shape) >= 2 and "token_embd" not in name:
273+
if name == "token_embd":
274+
weight *= config.embedding_multiplier_scale
275+
elif len(weight.shape) >= 2:
276+
# Transpose linear matrix
276277
weight = weight.transpose(-1, -2)
277278

279+
278280
if name.endswith("ffn_gate_inp") or name.endswith("_exps"):
279-
weight = weight[experts] # gather.
281+
weight = weight[config.experts] # gather.
280282

281283
return weight
282284

283285

284-
def maybe_permute_tensor(name, tensor, config):
285-
def permute(weights, n_head):
286-
return (
287-
weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
288-
.swapaxes(1, 2)
289-
.reshape(weights.shape)
290-
)
291-
292-
if name.endswith("attn_k"):
293-
return permute(tensor, config.num_key_value_heads)
294-
elif name.endswith("attn_q"):
295-
return permute(tensor, config.num_attention_heads)
296-
297-
return tensor
298-
299-
300286
def extract_vocabulary_from_model(vocab):
301287
tokens = []
302288
scores = []

0 commit comments

Comments
 (0)