import torch
import gguf
+ from quantize import group_dequantize_tensor_from_qparams
+
+ def to_float(t: gguf.gguf_reader.ReaderTensor):
+     """
+     Unpack and dequantize a GGUF tensor to a torch tensor of type torch.float32.
+     """
+
+     # Quantized weights are dequantized to float; float weights are upcast.
+     if t.tensor_type == gguf.GGMLQuantizationType.Q4_0:
+         return group_dequantize_tensor_from_qparams(*Q4_0.unpack(t), Q4_0.n_bit, Q4_0.groupsize).to(torch.float32)
+     elif t.tensor_type == gguf.GGMLQuantizationType.Q6_K:
+         return group_dequantize_tensor_from_qparams(*Q6_K.unpack(t), Q6_K.n_bit, Q6_K.groupsize).to(torch.float32)
+     elif t.tensor_type == gguf.GGMLQuantizationType.F16:
+         return F16.unpack(t).to(torch.float32)
+     elif t.tensor_type == gguf.GGMLQuantizationType.F32:
+         return F32.unpack(t).to(torch.float32)
+     else:
+         raise ValueError(f"Unsupported tensor type {t.tensor_type}")
+
+
+ def test_by_to_float(source_file: str, target_file: str) -> None:
+     """
+     Tests the methods in this file by converting tensors with to_float and
+     comparing them against a known-correct reference. Raises an error on any
+     mismatch.
+
+     In more detail, a GGUF source_file containing various GGUF tensor types is
+     parsed, and its tensors are converted with to_float. The results are then
+     compared against a GGUF target_file. The target file must contain only F32
+     tensors and should be generated by a method that is known to be correct.
+     """
+
+     gguf_sources = {t.name: t for t in gguf.GGUFReader(source_file, "r").tensors}
+     gguf_targets = {t.name: t for t in gguf.GGUFReader(target_file, "r").tensors}
+
+     for t in gguf_targets.values():
+         assert t.tensor_type == gguf.GGMLQuantizationType.F32, f"target_file must only contain F32 tensors, but found tensor {t.name} with type {repr(t.tensor_type)}."
+     assert gguf_sources.keys() == gguf_targets.keys(), "source_file and target_file should have the same tensors (by name)"
+
+     for k in gguf_sources:
+         source = to_float(gguf_sources[k])
+         target = to_float(gguf_targets[k])
+
+         if not torch.allclose(source, target):
+             print(f"After calling to_float on source tensor {k} of type {repr(gguf_sources[k].tensor_type)}, it does not match its target.")
+             print("First 5 elements of converted source: ", source.reshape(-1)[0:5])
+             print("First 5 elements of target: ", target.reshape(-1)[0:5])
+             assert False, "found mismatch"
+
+     print("All tensors match.")


class F16:
    @staticmethod
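
For context, the new helpers are meant to be driven roughly like this (a minimal sketch; the file names are hypothetical, and the F32 reference would come from a converter known to be correct, e.g. llama.cpp's conversion script):

    import gguf

    # Dequantize every tensor in a (hypothetical) quantized model.
    reader = gguf.GGUFReader("model-q4_0.gguf", "r")
    for tensor in reader.tensors:
        dequantized = to_float(tensor)  # always torch.float32
        print(tensor.name, tuple(dequantized.shape), dequantized.dtype)

    # End-to-end check against a known-good F32 export of the same model.
    test_by_to_float("model-q4_0.gguf", "model-f32.gguf")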
@@ -14,7 +64,7 @@ def unpack(gguf_tensor: gguf.gguf_reader.ReaderTensor):
        Unpacks GGUF F16 tensor.
        """
        assert gguf_tensor.tensor_type == gguf.GGMLQuantizationType.F16
-       reversed_shape = gguf_tensor.shape[::-1]  # TODO: GGUF tensors are reversed
+       reversed_shape = gguf_tensor.shape[::-1]
        new_tensor = gguf_tensor.data.reshape(reversed_shape)
        return torch.from_numpy(new_tensor).to(torch.float16)
@@ -25,7 +75,7 @@ def unpack(gguf_tensor: gguf.gguf_reader.ReaderTensor):
        Unpacks GGUF F32 tensor.
        """
        assert gguf_tensor.tensor_type == gguf.GGMLQuantizationType.F32
-       reversed_shape = gguf_tensor.shape[::-1]  # TODO: GGUF tensors are reversed
+       reversed_shape = gguf_tensor.shape[::-1]
        new_tensor = gguf_tensor.data.reshape(reversed_shape)
        return torch.from_numpy(new_tensor).to(torch.float32)
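
The shape reversal in both unpack methods exists because GGUF records dimensions in the opposite order from the row-major convention that numpy and torch use, so the reader's shape must be flipped before reshaping the flat buffer. A minimal numpy illustration (the shapes are made up):

    import numpy as np

    data = np.arange(6, dtype=np.float32)  # flat buffer as stored in the file
    gguf_shape = (3, 2)                    # shape as reported by the GGUF reader
    t = data.reshape(gguf_shape[::-1])     # row-major consumers want (2, 3)
    assert t.shape == (2, 3)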
@@ -61,7 +111,7 @@ def unpack(gguf_tensor: gguf.gguf_reader.ReaderTensor):

        assert gguf_tensor.tensor_type == gguf.GGMLQuantizationType.Q4_0
        assert len(gguf_tensor.shape) == 2
-       nc, nr = gguf_tensor.shape  # TODO: CHECK THIS. GGUF TENSOR REVERSED?
+       nc, nr = gguf_tensor.shape  # GGUF tensor has reversed shape

        QK4_0 = 32  # groupsize
@@ -84,7 +134,7 @@ def unpack(gguf_tensor: gguf.gguf_reader.ReaderTensor):
        # Check we finished parsing
        assert curr == block_q4_0_size

-       # Unpack quantized values. Unlike the code in ggml-quants.c, we do not subtract 16
+       # Unpack quantized values. Unlike the code in ggml-quants.c, we do not subtract 8
        x0 = qs & 0x0F
        x1 = qs >> 4
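
The comment above points at a convention difference worth spelling out: ggml-quants.c dequantizes Q4_0 as s * (q - 8), while this file leaves the 4-bit codes unshifted and instead reports 8 as the group zero point, so group_dequantize_tensor_from_qparams reconstructs the same values. A small sketch of the equivalence (the numbers are made up):

    # One packed byte holds two 4-bit codes.
    qs = 0x2A
    x0 = qs & 0x0F        # low nibble  -> 10
    x1 = qs >> 4          # high nibble -> 2
    s, z = 0.5, 8         # per-group scale; zero point stands in for the -8 shift
    assert s * (x0 - 8) == s * (x0 - z) == 1.0   # ggml-style == zero-point style
    assert s * (x1 - 8) == s * (x1 - z) == -3.0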
@@ -117,8 +167,6 @@ def unpack(gguf_tensor: gguf.gguf_reader.ReaderTensor):
        * s is a torch.float32 tensor of shape (nr, -1) with one scale per group
        * z is a torch.float32 tensor of shape (nr, -1) with one zero per group

-       There is one element of s/z per group of 32 elements of 4.
-
        Note that z is always zero because Q6_K is a scale-only scheme.

        See https://github.com/ggerganov/llama.cpp/blob/master/ggml-common.h for definition of block_q6_K:
@@ -142,7 +190,7 @@ def unpack(gguf_tensor: gguf.gguf_reader.ReaderTensor):
        """
        assert gguf_tensor.tensor_type == gguf.GGMLQuantizationType.Q6_K
        assert len(gguf_tensor.shape) == 2
-       nc, nr = gguf_tensor.shape  # TODO: CHECK THIS. GGUF TENSOR REVERSED?
+       nc, nr = gguf_tensor.shape  # GGUF tensor has reversed shape

        QK_K = 256

        # Parse block_q6_K
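
Because Q6_K is scale-only, the z tensor returned by unpack is all zeros, and the shared group dequantization effectively reduces to w = s * q per group. The sketch below shows the reconstruction this file appears to rely on; it is a hypothetical stand-in for group_dequantize_tensor_from_qparams (whose actual signature and internals live in quantize.py), assuming one scale per group of 16 elements:

    import torch

    def group_dequantize_sketch(q, s, z, groupsize):
        # Reshape rows into groups, apply w = (q - z) * s, and flatten back.
        nr = q.shape[0]
        qg = q.reshape(nr, -1, groupsize).to(torch.float32)
        return ((qg - z.unsqueeze(-1)) * s.unsqueeze(-1)).reshape(nr, -1)

    q = torch.randint(0, 64, (2, 256))  # 6-bit codes, one row per output channel
    s = torch.rand(2, 256 // 16)        # one scale per group
    z = torch.zeros_like(s)             # Q6_K: zero point is always 0
    w = group_dequantize_sketch(q, s, z, groupsize=16)
    assert w.shape == (2, 256)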