Commit 0a6d5ad

Reuse definitions from convert.py
1 parent 63da54e commit 0a6d5ad

File tree

1 file changed: +7, -39 lines


convert-lora-to-ggml.py

Lines changed: 7 additions & 39 deletions
@@ -3,45 +3,11 @@
 import re
 import struct
 import sys
-from dataclasses import dataclass
-from typing import Any, Sequence
+from typing import Any, Dict, Sequence, TextIO
 
-import numpy as np
 import torch
 
-
-# TODO: import this from convert.py once #545 is merged
-@dataclass(frozen=True)
-class UnquantizedDataType:
-    name: str
-
-
-DT_F16 = UnquantizedDataType("F16")
-DT_F32 = UnquantizedDataType("F32")
-
-
-@dataclass(frozen=True)
-class QuantizedDataType:
-    groupsize: int
-    have_addends: bool
-    have_g_idx: bool
-
-
-DataType = UnquantizedDataType
-
-DATA_TYPE_TO_FTYPE: dict[DataType, int] = {
-    DT_F32: 0,
-    DT_F16: 1,
-}
-
-DATA_TYPE_TO_NUMPY: dict[DataType, np.dtype[Any]] = {
-    DT_F16: np.dtype(np.float16),
-    DT_F32: np.dtype(np.float32),
-}
-
-NUMPY_TYPE_TO_DATA_TYPE: dict[np.dtype[Any], DataType] = {
-    dtype: data_type for (data_type, dtype) in DATA_TYPE_TO_NUMPY.items()
-}
+from convert import DATA_TYPE_TO_FTYPE, NUMPY_TYPE_TO_DATA_TYPE, DataType
 
 HF_SUBLAYER_TO_GGML = {
     "self_attn.q_proj": "attention.wq",
@@ -59,7 +25,7 @@ class QuantizedDataType:
 }
 
 
-def translate_tensor_name(t):
+def translate_tensor_name(t: str) -> str:
     match = re.match(r".*layers\.(\d+)\.(\w+\.\w+)\.lora_(A|B)\.weight", t)
     if match:
         nn = match.group(1)
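
As a quick illustration of the regex above (the key below is a hypothetical Hugging Face LoRA checkpoint name; the exact output string is assembled further down in the script):

import re

t = "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight"  # hypothetical key
match = re.match(r".*layers\.(\d+)\.(\w+\.\w+)\.lora_(A|B)\.weight", t)
assert match is not None
# match.group(1) == "0"                 -> layer index
# match.group(2) == "self_attn.q_proj"  -> sublayer, looked up in HF_SUBLAYER_TO_GGML
# match.group(3) == "A"                 -> which LoRA matrix (A or B)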
@@ -80,13 +46,15 @@ def translate_tensor_name(t):
     sys.exit(1)
 
 
-def write_file_header(fout, params):
+def write_file_header(fout: TextIO, params: Dict[str, Any]) -> None:
     fout.write(b"ggla"[::-1])  # magic (ggml lora)
     fout.write(struct.pack("i", 1))  # file version
     fout.write(struct.pack("ii", params["r"], params["lora_alpha"]))
 
 
-def write_tensor_header(self, name: str, shape: Sequence[int], data_type: 1) -> None:
+def write_tensor_header(
+    self, name: str, shape: Sequence[int], data_type: DataType
+) -> None:
     sname = name.encode("utf-8")
     fout.write(
         struct.pack(
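
For reference, a minimal sketch of reading that file header back, assuming a hypothetical output path "adapter.bin" and the same native byte order that struct.pack uses above. The header is the 4-byte magic, a 4-byte version, then two 4-byte ints for r and lora_alpha.

import struct

with open("adapter.bin", "rb") as fin:                # hypothetical path
    assert fin.read(4) == b"ggla"[::-1]               # magic bytes, reversed as written
    (version,) = struct.unpack("i", fin.read(4))      # file version, expected to be 1
    r, lora_alpha = struct.unpack("ii", fin.read(8))  # LoRA rank r and lora_alpha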
