Skip to content

Commit 60b29ea

Browse files
committed
More constants from gguf.
1 parent e2f13a3 commit 60b29ea

File tree

1 file changed: +10 additions, -10 deletions

convert_grok.py

Lines changed: 10 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -35,9 +35,9 @@
3535

3636
import gguf
3737

38-
GGML_QK8_0 = 32
39-
GGML_QK4_0 = 32
40-
GGML_QK4_1 = 32
38+
QK8_0 = gguf.GGML_QUANT_SIZES[gguf.GGMLQuantizationType.Q8_0][0]
39+
QK4_0 = gguf.GGML_QUANT_SIZES[gguf.GGMLQuantizationType.Q4_0][0]
40+
QK4_1 = gguf.GGML_QUANT_SIZES[gguf.GGMLQuantizationType.Q4_1][0]
4141

4242

4343
# Heuristic to avoid having to fully parse pickle files.
@@ -125,8 +125,8 @@ def get_weights(fn):
125125

126126
def quantize_q8_0(tensor: torch.Tensor) -> torch.CharTensor:
127127
# equivalent to ggml_quantize_q8_0 in ggml.c (modulo rounding away from zero)
128-
assert tensor.shape[1] % GGML_QK8_0 == 0
129-
tensor = tensor.reshape(-1, GGML_QK8_0)
128+
assert tensor.shape[1] % QK8_0 == 0
129+
tensor = tensor.reshape(-1, QK8_0)
130130
scale = tensor.abs().max(dim=-1, keepdim=True).values / ((1 << 7) - 1)
131131
tensor = (tensor / scale).round().clamp(min=-128, max=127).char()
132132
# add scale into each block
@@ -136,8 +136,8 @@ def quantize_q8_0(tensor: torch.Tensor) -> torch.CharTensor:
136136

137137
def quantize_q4_0(tensor: torch.Tensor) -> torch.CharTensor:
138138
# equivalent to ggml_quantize_q4_0 in ggml.c (modulo rounding away from zero)
139-
assert tensor.shape[1] % GGML_QK4_0 == 0
140-
tensor = tensor.reshape(-1, GGML_QK4_0)
139+
assert tensor.shape[1] % QK4_0 == 0
140+
tensor = tensor.reshape(-1, QK4_0)
141141
abs_max_indices = tensor.abs().max(dim=-1, keepdim=True).indices
142142
max_values = torch.take_along_dim(tensor, abs_max_indices, dim=-1)
143143
scale = max_values / -8
@@ -151,8 +151,8 @@ def quantize_q4_0(tensor: torch.Tensor) -> torch.CharTensor:
151151

152152
def quantize_q4_1(tensor: torch.Tensor) -> torch.CharTensor:
153153
# equivalent to ggml_quantize_q4_1 in ggml.c (modulo rounding away from zero)
154-
assert tensor.shape[1] % GGML_QK4_1 == 0
155-
tensor = tensor.reshape(-1, GGML_QK4_1)
154+
assert tensor.shape[1] % QK4_1 == 0
155+
tensor = tensor.reshape(-1, QK4_1)
156156
abs_max_indices = tensor.max(dim=-1, keepdim=True).indices
157157
max_values = torch.take_along_dim(tensor, abs_max_indices, dim=-1)
158158
abs_min_indices = tensor.min(dim=-1, keepdim=True).indices
@@ -188,7 +188,7 @@ def maybe_quantize_tensor(tensor, ggml_type):
188188

189189
def get_dtype_and_ggml_type(name, tensor, ggml_type):
190190
if tensor.ndim in (2, 3) and "ffn_gate_inp" not in name:
191-
if tensor.shape[1] % GGML_QK8_0 == 0:
191+
if tensor.shape[1] % QK8_0 == 0:
192192
return np.int8, ggml_type
193193
else:
194194
return np.float16, gguf.GGMLQuantizationType.F16

0 commit comments

Comments (0)