Skip to content

Commit 158be8f

Browse files
committed
gguf.py : some code style changes
1 parent d2b6ca1 commit 158be8f

File tree

2 files changed

+92
-94
lines changed

2 files changed

+92
-94
lines changed

constants.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,32 @@
1-
GGUF_MAGIC = 0x47475546
1+
GGUF_MAGIC = 0x47475546
22
GGUF_VERSION = 1
33

44
# general
5-
KEY_GENERAL_ARCHITECTURE = "general.architecture"
5+
KEY_GENERAL_ARCHITECTURE = "general.architecture"
66
KEY_GENERAL_QUANTIZATION_VERSION = "general.quantization_version"
7-
KEY_GENERAL_NAME = "general.name"
8-
KEY_GENERAL_AUTHOR = "general.author"
9-
KEY_GENERAL_URL = "general.url"
10-
KEY_GENERAL_DESCRIPTION = "general.description"
11-
KEY_GENERAL_FILE_TYPE = "general.file_type"
12-
KEY_GENERAL_LICENSE = "general.license"
13-
KEY_GENERAL_SOURCE_URL = "general.source.url"
14-
KEY_GENERAL_SOURCE_HF_REPO = "general.source.hugginface.repository"
7+
KEY_GENERAL_NAME = "general.name"
8+
KEY_GENERAL_AUTHOR = "general.author"
9+
KEY_GENERAL_URL = "general.url"
10+
KEY_GENERAL_DESCRIPTION = "general.description"
11+
KEY_GENERAL_FILE_TYPE = "general.file_type"
12+
KEY_GENERAL_LICENSE = "general.license"
13+
KEY_GENERAL_SOURCE_URL = "general.source.url"
14+
KEY_GENERAL_SOURCE_HF_REPO = "general.source.hugginface.repository"
1515

1616
# LLM
17-
KEY_LLM_CONTEXT_LENGTH = "{llm}.context_length"
18-
KEY_LLM_EMBEDDING_LENGTH = "{llm}.embedding_length"
19-
KEY_LLM_LAYER_COUNT = "{llm}.layer_count"
20-
KEY_LLM_FEED_FORWARD_LENGTH = "{llm}.feed_forward_length"
21-
KEY_LLM_USE_PARALLEL_RESIDUAL = "{llm}.use_parallel_residual"
22-
KEY_LLM_TENSOR_DATA_LAYOUT = "{llm}.tensor_data_layout"
17+
KEY_LLM_CONTEXT_LENGTH = "{llm}.context_length"
18+
KEY_LLM_EMBEDDING_LENGTH = "{llm}.embedding_length"
19+
KEY_LLM_LAYER_COUNT = "{llm}.layer_count"
20+
KEY_LLM_FEED_FORWARD_LENGTH = "{llm}.feed_forward_length"
21+
KEY_LLM_USE_PARALLEL_RESIDUAL = "{llm}.use_parallel_residual"
22+
KEY_LLM_TENSOR_DATA_LAYOUT = "{llm}.tensor_data_layout"
2323

2424
# attention
25-
KEY_ATTENTION_HEAD_COUNT = "{llm}.attention.head_count"
26-
KEY_ATTENTION_HEAD_COUNT_KV = "{llm}.attention.head_count_kv"
27-
KEY_ATTENTION_MAX_ALIBI_BIAS = "{llm}.attention.max_alibi_bias"
28-
KEY_ATTENTION_CLAMP_KQV = "{llm}.attention.clamp_kqv"
25+
KEY_ATTENTION_HEAD_COUNT = "{llm}.attention.head_count"
26+
KEY_ATTENTION_HEAD_COUNT_KV = "{llm}.attention.head_count_kv"
27+
KEY_ATTENTION_MAX_ALIBI_BIAS = "{llm}.attention.max_alibi_bias"
28+
KEY_ATTENTION_CLAMP_KQV = "{llm}.attention.clamp_kqv"
2929

3030
# RoPE
31-
KEY_ROPE_DIMENSION_COUNT = "{llm}.rope.dimension_count"
32-
KEY_ROPE_SCALE = "{llm}.rope.scale"
31+
KEY_ROPE_DIMENSION_COUNT = "{llm}.rope.dimension_count"
32+
KEY_ROPE_SCALE = "{llm}.rope.scale"

gguf.py

Lines changed: 70 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,13 @@
66
"""
77

88
import struct
9+
import constants
910
from enum import IntEnum
1011
from typing import List, Any
11-
import constants
12-
1312

1413
class GGMLQuantizationType(IntEnum):
15-
F32 = 0
16-
F16 = 1
14+
F32 = 0
15+
F16 = 1
1716
QR_0 = 2
1817
Q4_1 = 3
1918
# Q4_2 = 4 # support has been removed
@@ -31,31 +30,30 @@ class GGMLQuantizationType(IntEnum):
3130

3231

3332
class GGUFValueType(IntEnum):
34-
UINT8 = 0
35-
INT8 = 1
36-
UINT16 = 2
37-
INT16 = 3
38-
UINT32 = 4
39-
INT32 = 5
33+
UINT8 = 0
34+
INT8 = 1
35+
UINT16 = 2
36+
INT16 = 3
37+
UINT32 = 4
38+
INT32 = 5
4039
FLOAT32 = 6
41-
BOOL = 7
42-
STRING = 8
43-
ARRAY = 9
40+
BOOL = 7
41+
STRING = 8
42+
ARRAY = 9
4443

4544
@staticmethod
46-
def get_type(value):
47-
if isinstance(value, str):
45+
def get_type(val):
46+
if isinstance(val, str):
4847
return GGUFValueType.STRING
49-
elif isinstance(value, list):
48+
elif isinstance(val, list):
5049
return GGUFValueType.ARRAY
51-
elif isinstance(value, float):
50+
elif isinstance(val, float):
5251
return GGUFValueType.FLOAT32
53-
elif isinstance(value, bool):
52+
elif isinstance(val, bool):
5453
return GGUFValueType.BOOL
5554
else:
5655
return GGUFValueType.INT32
5756

58-
5957
class GGUFWriter:
6058
def __init__(self, buffered_writer):
6159
self.buffered_writer = buffered_writer
@@ -72,81 +70,81 @@ def open(cls, path: str) -> "GGUFWriter":
7270
return cls(f)
7371

7472
def write_key(self, key: str):
75-
self.write_value(key, GGUFValueType.STRING)
73+
self.write_val(key, GGUFValueType.STRING)
7674

77-
def write_uint8(self, key: str, value: int):
75+
def write_uint8(self, key: str, val: int):
7876
self.write_key(key)
79-
self.write_value(value, GGUFValueType.UINT8)
77+
self.write_val(val, GGUFValueType.UINT8)
8078

81-
def write_int8(self, key: str, value: int):
79+
def write_int8(self, key: str, val: int):
8280
self.write_key(key)
83-
self.write_value(value, GGUFValueType.INT8)
81+
self.write_val(val, GGUFValueType.INT8)
8482

85-
def write_uint16(self, key: str, value: int):
83+
def write_uint16(self, key: str, val: int):
8684
self.write_key(key)
87-
self.write_value(value, GGUFValueType.UINT16)
85+
self.write_val(val, GGUFValueType.UINT16)
8886

89-
def write_int16(self, key: str, value: int):
87+
def write_int16(self, key: str, val: int):
9088
self.write_key(key)
91-
self.write_value(value, GGUFValueType.INT16)
89+
self.write_val(val, GGUFValueType.INT16)
9290

93-
def write_uint32(self, key: str, value: int):
91+
def write_uint32(self, key: str, val: int):
9492
self.write_key(key)
95-
self.write(value, GGUFValueType.UINT32)
93+
self.write_val(val, GGUFValueType.UINT32)
9694

97-
def write_int32(self, key: str, value: int):
95+
def write_int32(self, key: str, val: int):
9896
self.write_key(key)
99-
self.write_value(value, GGUFValueType.INT32)
97+
self.write_val(val, GGUFValueType.INT32)
10098

101-
def write_float32(self, key: str, value: float):
99+
def write_float32(self, key: str, val: float):
102100
self.write_key(key)
103-
self.write_value(value, GGUFValueType.FLOAT32)
101+
self.write_val(val, GGUFValueType.FLOAT32)
104102

105-
def write_bool(self, key: str, value: bool):
103+
def write_bool(self, key: str, val: bool):
106104
self.write_key(key)
107-
self.write_value(value, GGUFValueType.BOOL)
105+
self.write_val(val, GGUFValueType.BOOL)
108106

109-
def write_string(self, key: str, value: str):
107+
def write_string(self, key: str, val: str):
110108
self.write_key(key)
111-
self.write_value(value, GGUFValueType.STRING)
109+
self.write_val(val, GGUFValueType.STRING)
112110

113-
def write_array(self, key: str, value: list):
114-
if not isinstance(value, list):
111+
def write_array(self, key: str, val: list):
112+
if not isinstance(val, list):
115113
raise ValueError("Value must be a list for array type")
116114

117115
self.write_key(key)
118-
self.write_value(value, GGUFValueType.ARRAY)
119-
120-
def write_value(self: str, value: Any, value_type: GGUFValueType = None):
121-
if value_type is None:
122-
value_type = GGUFValueType.get_type(value)
123-
124-
self.buffered_writer.write(struct.pack("<I", value_type))
125-
126-
if value_type == GGUFValueType.UINT8:
127-
self.buffered_writer.write(struct.pack("<B", value))
128-
elif value_type == GGUFValueType.INT8:
129-
self.buffered_writer.write(struct.pack("<b", value))
130-
elif value_type == GGUFValueType.UINT16:
131-
self.buffered_writer.write(struct.pack("<H", value))
132-
elif value_type == GGUFValueType.INT16:
133-
self.buffered_writer.write(struct.pack("<h", value))
134-
elif value_type == GGUFValueType.UINT32:
135-
self.buffered_writer.write(struct.pack("<I", value))
136-
elif value_type == GGUFValueType.INT32:
137-
self.buffered_writer.write(struct.pack("<i", value))
138-
elif value_type == GGUFValueType.FLOAT32:
139-
self.buffered_writer.write(struct.pack("<f", value))
140-
elif value_type == GGUFValueType.BOOL:
141-
self.buffered_writer.write(struct.pack("?", value))
142-
elif value_type == GGUFValueType.STRING:
143-
encoded_value = value.encode("utf8")
144-
self.buffered_writer.write(struct.pack("<I", len(encoded_value)))
145-
self.buffered_writer.write(encoded_value)
146-
elif value_type == GGUFValueType.ARRAY:
147-
self.buffered_writer.write(struct.pack("<I", len(value)))
148-
for item in value:
149-
self.write_value(item)
116+
self.write_val(val, GGUFValueType.ARRAY)
117+
118+
def write_val(self: str, val: Any, vtype: GGUFValueType = None):
119+
if vtype is None:
120+
vtype = GGUFValueType.get_type(val)
121+
122+
self.buffered_writer.write(struct.pack("<I", vtype))
123+
124+
if vtype == GGUFValueType.UINT8:
125+
self.buffered_writer.write(struct.pack("<B", val))
126+
elif vtype == GGUFValueType.INT8:
127+
self.buffered_writer.write(struct.pack("<b", val))
128+
elif vtype == GGUFValueType.UINT16:
129+
self.buffered_writer.write(struct.pack("<H", val))
130+
elif vtype == GGUFValueType.INT16:
131+
self.buffered_writer.write(struct.pack("<h", val))
132+
elif vtype == GGUFValueType.UINT32:
133+
self.buffered_writer.write(struct.pack("<I", val))
134+
elif vtype == GGUFValueType.INT32:
135+
self.buffered_writer.write(struct.pack("<i", val))
136+
elif vtype == GGUFValueType.FLOAT32:
137+
self.buffered_writer.write(struct.pack("<f", val))
138+
elif vtype == GGUFValueType.BOOL:
139+
self.buffered_writer.write(struct.pack("?", val))
140+
elif vtype == GGUFValueType.STRING:
141+
encoded_val = val.encode("utf8")
142+
self.buffered_writer.write(struct.pack("<I", len(encoded_val)))
143+
self.buffered_writer.write(encoded_val)
144+
elif vtype == GGUFValueType.ARRAY:
145+
self.buffered_writer.write(struct.pack("<I", len(val)))
146+
for item in val:
147+
self.write_val(item)
150148
else:
151149
raise ValueError("Invalid GGUF metadata value type")
152150

0 commit comments

Comments
 (0)