6
6
"""
7
7
8
8
import struct
9
+ import constants
9
10
from enum import IntEnum
10
11
from typing import List , Any
11
- import constants
12
-
13
12
14
13
class GGMLQuantizationType (IntEnum ):
15
- F32 = 0
16
- F16 = 1
14
+ F32 = 0
15
+ F16 = 1
17
16
QR_0 = 2
18
17
Q4_1 = 3
19
18
# Q4_2 = 4 # support has been removed
@@ -31,31 +30,30 @@ class GGMLQuantizationType(IntEnum):
31
30
32
31
33
32
class GGUFValueType (IntEnum ):
34
- UINT8 = 0
35
- INT8 = 1
36
- UINT16 = 2
37
- INT16 = 3
38
- UINT32 = 4
39
- INT32 = 5
33
+ UINT8 = 0
34
+ INT8 = 1
35
+ UINT16 = 2
36
+ INT16 = 3
37
+ UINT32 = 4
38
+ INT32 = 5
40
39
FLOAT32 = 6
41
- BOOL = 7
42
- STRING = 8
43
- ARRAY = 9
40
+ BOOL = 7
41
+ STRING = 8
42
+ ARRAY = 9
44
43
45
44
@staticmethod
46
- def get_type (value ):
47
- if isinstance (value , str ):
45
+ def get_type (val ):
46
+ if isinstance (val , str ):
48
47
return GGUFValueType .STRING
49
- elif isinstance (value , list ):
48
+ elif isinstance (val , list ):
50
49
return GGUFValueType .ARRAY
51
- elif isinstance (value , float ):
50
+ elif isinstance (val , float ):
52
51
return GGUFValueType .FLOAT32
53
- elif isinstance (value , bool ):
52
+ elif isinstance (val , bool ):
54
53
return GGUFValueType .BOOL
55
54
else :
56
55
return GGUFValueType .INT32
57
56
58
-
59
57
class GGUFWriter :
60
58
def __init__ (self , buffered_writer ):
61
59
self .buffered_writer = buffered_writer
@@ -72,81 +70,81 @@ def open(cls, path: str) -> "GGUFWriter":
72
70
return cls (f )
73
71
74
72
def write_key (self , key : str ):
75
- self .write_value (key , GGUFValueType .STRING )
73
+ self .write_val (key , GGUFValueType .STRING )
76
74
77
- def write_uint8 (self , key : str , value : int ):
75
+ def write_uint8 (self , key : str , val : int ):
78
76
self .write_key (key )
79
- self .write_value ( value , GGUFValueType .UINT8 )
77
+ self .write_val ( val , GGUFValueType .UINT8 )
80
78
81
- def write_int8 (self , key : str , value : int ):
79
+ def write_int8 (self , key : str , val : int ):
82
80
self .write_key (key )
83
- self .write_value ( value , GGUFValueType .INT8 )
81
+ self .write_val ( val , GGUFValueType .INT8 )
84
82
85
- def write_uint16 (self , key : str , value : int ):
83
+ def write_uint16 (self , key : str , val : int ):
86
84
self .write_key (key )
87
- self .write_value ( value , GGUFValueType .UINT16 )
85
+ self .write_val ( val , GGUFValueType .UINT16 )
88
86
89
- def write_int16 (self , key : str , value : int ):
87
+ def write_int16 (self , key : str , val : int ):
90
88
self .write_key (key )
91
- self .write_value ( value , GGUFValueType .INT16 )
89
+ self .write_val ( val , GGUFValueType .INT16 )
92
90
93
- def write_uint32 (self , key : str , value : int ):
91
+ def write_uint32 (self , key : str , val : int ):
94
92
self .write_key (key )
95
- self .write ( value , GGUFValueType .UINT32 )
93
+ self .write_val ( val , GGUFValueType .UINT32 )
96
94
97
- def write_int32 (self , key : str , value : int ):
95
+ def write_int32 (self , key : str , val : int ):
98
96
self .write_key (key )
99
- self .write_value ( value , GGUFValueType .INT32 )
97
+ self .write_val ( val , GGUFValueType .INT32 )
100
98
101
- def write_float32 (self , key : str , value : float ):
99
+ def write_float32 (self , key : str , val : float ):
102
100
self .write_key (key )
103
- self .write_value ( value , GGUFValueType .FLOAT32 )
101
+ self .write_val ( val , GGUFValueType .FLOAT32 )
104
102
105
- def write_bool (self , key : str , value : bool ):
103
+ def write_bool (self , key : str , val : bool ):
106
104
self .write_key (key )
107
- self .write_value ( value , GGUFValueType .BOOL )
105
+ self .write_val ( val , GGUFValueType .BOOL )
108
106
109
- def write_string (self , key : str , value : str ):
107
+ def write_string (self , key : str , val : str ):
110
108
self .write_key (key )
111
- self .write_value ( value , GGUFValueType .STRING )
109
+ self .write_val ( val , GGUFValueType .STRING )
112
110
113
- def write_array (self , key : str , value : list ):
114
- if not isinstance (value , list ):
111
+ def write_array (self , key : str , val : list ):
112
+ if not isinstance (val , list ):
115
113
raise ValueError ("Value must be a list for array type" )
116
114
117
115
self .write_key (key )
118
- self .write_value ( value , GGUFValueType .ARRAY )
119
-
120
- def write_value (self : str , value : Any , value_type : GGUFValueType = None ):
121
- if value_type is None :
122
- value_type = GGUFValueType .get_type (value )
123
-
124
- self .buffered_writer .write (struct .pack ("<I" , value_type ))
125
-
126
- if value_type == GGUFValueType .UINT8 :
127
- self .buffered_writer .write (struct .pack ("<B" , value ))
128
- elif value_type == GGUFValueType .INT8 :
129
- self .buffered_writer .write (struct .pack ("<b" , value ))
130
- elif value_type == GGUFValueType .UINT16 :
131
- self .buffered_writer .write (struct .pack ("<H" , value ))
132
- elif value_type == GGUFValueType .INT16 :
133
- self .buffered_writer .write (struct .pack ("<h" , value ))
134
- elif value_type == GGUFValueType .UINT32 :
135
- self .buffered_writer .write (struct .pack ("<I" , value ))
136
- elif value_type == GGUFValueType .INT32 :
137
- self .buffered_writer .write (struct .pack ("<i" , value ))
138
- elif value_type == GGUFValueType .FLOAT32 :
139
- self .buffered_writer .write (struct .pack ("<f" , value ))
140
- elif value_type == GGUFValueType .BOOL :
141
- self .buffered_writer .write (struct .pack ("?" , value ))
142
- elif value_type == GGUFValueType .STRING :
143
- encoded_value = value .encode ("utf8" )
144
- self .buffered_writer .write (struct .pack ("<I" , len (encoded_value )))
145
- self .buffered_writer .write (encoded_value )
146
- elif value_type == GGUFValueType .ARRAY :
147
- self .buffered_writer .write (struct .pack ("<I" , len (value )))
148
- for item in value :
149
- self .write_value (item )
116
+ self .write_val ( val , GGUFValueType .ARRAY )
117
+
118
+ def write_val (self : str , val : Any , vtype : GGUFValueType = None ):
119
+ if vtype is None :
120
+ vtype = GGUFValueType .get_type (val )
121
+
122
+ self .buffered_writer .write (struct .pack ("<I" , vtype ))
123
+
124
+ if vtype == GGUFValueType .UINT8 :
125
+ self .buffered_writer .write (struct .pack ("<B" , val ))
126
+ elif vtype == GGUFValueType .INT8 :
127
+ self .buffered_writer .write (struct .pack ("<b" , val ))
128
+ elif vtype == GGUFValueType .UINT16 :
129
+ self .buffered_writer .write (struct .pack ("<H" , val ))
130
+ elif vtype == GGUFValueType .INT16 :
131
+ self .buffered_writer .write (struct .pack ("<h" , val ))
132
+ elif vtype == GGUFValueType .UINT32 :
133
+ self .buffered_writer .write (struct .pack ("<I" , val ))
134
+ elif vtype == GGUFValueType .INT32 :
135
+ self .buffered_writer .write (struct .pack ("<i" , val ))
136
+ elif vtype == GGUFValueType .FLOAT32 :
137
+ self .buffered_writer .write (struct .pack ("<f" , val ))
138
+ elif vtype == GGUFValueType .BOOL :
139
+ self .buffered_writer .write (struct .pack ("?" , val ))
140
+ elif vtype == GGUFValueType .STRING :
141
+ encoded_val = val .encode ("utf8" )
142
+ self .buffered_writer .write (struct .pack ("<I" , len (encoded_val )))
143
+ self .buffered_writer .write (encoded_val )
144
+ elif vtype == GGUFValueType .ARRAY :
145
+ self .buffered_writer .write (struct .pack ("<I" , len (val )))
146
+ for item in val :
147
+ self .write_val (item )
150
148
else :
151
149
raise ValueError ("Invalid GGUF metadata value type" )
152
150
0 commit comments