@@ -1,4 +1,4 @@
-# HF llama --> gguf conversion, GQA/70b not supported
+# HF llama --> gguf conversion
 
 import gguf
 import gguf_namemap as tmap
@@ -10,19 +10,19 @@
 import numpy as np
 import torch
 
-from typing import Any, List
+from typing import Any, List, Optional
 from pathlib import Path
 from sentencepiece import SentencePieceProcessor
 
 #NDArray = np.ndarray[Any, Any]
 # compatible with python < 3.9
 NDArray: 'TypeAlias' = 'np.ndarray[Any, Any]'
 
-
-def permute(weights: NDArray, n_head: int) -> NDArray:
+def permute(weights: NDArray, n_head: int, n_kv_head: Optional[int] = None) -> NDArray:
+    if n_kv_head is not None and n_head != n_kv_head: n_head //= n_kv_head
     return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
-            .swapaxes(1, 2)
-            .reshape(weights.shape))
+                .swapaxes(1, 2)
+                .reshape(weights.shape))
 
 def count_model_parts(dir_model: str) -> int:
     num_parts = 0
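For reference (not part of the commit): the permute reorders the rows of a q/k projection weight so that the two halves of each head end up interleaved, matching the RoPE row layout expected on the llama.cpp/GGML side. A minimal sketch with a made-up 2-head, head-dim-4 toy weight, just to make the reordering visible:

import numpy as np

def permute(weights, n_head, n_kv_head=None):
    # same reshape/swapaxes trick as in the converter above
    if n_kv_head is not None and n_head != n_kv_head:
        n_head //= n_kv_head
    return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
                   .swapaxes(1, 2)
                   .reshape(weights.shape))

# toy q_proj: 2 heads x head_dim 4 = 8 rows, numbered so the reorder is easy to see
w = np.arange(8, dtype=np.float32).reshape(8, 1)
print(permute(w, n_head=2).ravel())   # [0. 2. 1. 3. 4. 6. 5. 7.]

With the new n_kv_head argument, the reshape is done in n_head // n_kv_head groups: for a GQA model such as LLaMA-2 70B (64 heads, 8 KV heads) the projections are permuted in 64 // 8 = 8 groups instead of 64.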
@@ -220,7 +220,7 @@ def count_model_parts(dir_model: str) -> int:
 
     # permute these
     if name.endswith(".q_proj.weight") or name.endswith(".k_proj.weight"):
-        data = permute(data,head_count)
+        data = permute(data, head_count, head_count_kv)
 
     # map tensor names
     if name.endswith(".weight") and name[:-7] in tensor_map:
@@ -289,7 +289,7 @@ def count_model_parts(dir_model: str) -> int:
 
     # permute these
     if name.endswith(".q_proj.weight") or name.endswith(".k_proj.weight"):
-        data = permute(data, head_count)
+        data = permute(data, head_count, head_count_kv)
 
 
     if name.endswith(".weight") and name[:-7] in tensor_map:
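The head_count_kv passed at both call sites is set elsewhere in the script, outside this diff. Purely as an assumption about where it comes from, a minimal sketch of reading it from the checkpoint's HF config.json, falling back to the regular head count for non-GQA models (num_attention_heads and num_key_value_heads are the standard HF LLaMA config keys; the hparams dict is illustrative):

import json

with open("config.json", "r") as f:
    hparams = json.load(f)

head_count = hparams["num_attention_heads"]
# GQA checkpoints (e.g. LLaMA-2 70B) carry num_key_value_heads;
# older LLaMA configs do not, so fall back to plain multi-head attention.
head_count_kv = hparams.get("num_key_value_heads", head_count)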