Skip to content

Commit 678916f

Browse files
committed
ggml-quants : use ceiling division when quantizing q1_3
1 parent 5fefd47 commit 678916f

File tree

4 files changed

+12
-7
lines changed

4 files changed

+12
-7
lines changed

convert-hf-to-gguf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,7 @@ def write_tensors(self):
334334
shape = gguf.quant_shape_from_byte_shape(data.shape, data_qtype) if data.dtype == np.uint8 else data.shape
335335

336336
# reverse shape to make it similar to the internal ggml dimension order
337-
shape_str = f"{{{', '.join(str(n) for n in reversed(shape))}}}"
337+
shape_str = f"{{{', '.join(str(n) for n in reversed(shape)) or '1'}}}"
338338

339339
# n_dims is implicit in the shape
340340
logger.info(f"{f'%-{max_name_len}s' % f'{new_name},'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}")

ggml-quants.c

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3389,8 +3389,8 @@ void quantize_row_q1_3_reference(const float * restrict x, block_q1_3 * restrict
33893389
int xi = nearest_int(x[j]);
33903390
uint8_t xt = xi < 0 ? 0 : xi == 0 ? 1 : 2;
33913391
q[j] += xt * pow3[4];
3392-
q[j] = ((uint16_t)q[j] * 256) / pow3[5];
3393-
q[j] += (uint8_t)(q[j] != 0);
3392+
// ceiling division
3393+
q[j] = ((uint16_t)q[j] * 256 + (pow3[5] - 1)) / pow3[5];
33943394
y[i].q[j] = q[j];
33953395
}
33963396
x += sizeof(y->q);
@@ -3403,8 +3403,8 @@ void quantize_row_q1_3_reference(const float * restrict x, block_q1_3 * restrict
34033403
qb += xt * pow3[m];
34043404
}
34053405
x += 4;
3406-
qb = ((uint16_t)qb * 256) / pow3[5];
3407-
qb += (uint8_t)(qb != 0);
3406+
// ceiling division
3407+
qb = ((uint16_t)qb * 256 + (pow3[5] - 1)) / pow3[5];
34083408
y[i].qs[j] = qb;
34093409
}
34103410
}

gguf-py/gguf/quants.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,8 +149,7 @@ def __quantize_q1_3_rows(n: np.ndarray) -> np.ndarray:
149149
q4 = np.sum(q4 * pow3.reshape((1, 4)), axis=1, keepdims=True)
150150
q48 = q48 + (q12 * 81)
151151
q = np.concatenate([q48, q4], axis=1)
152-
q = ((q.astype(np.uint16) * 256) // 243).astype(np.uint8)
153-
q = np.where(q != 0, q + 1, 0)
152+
q = (((q.astype(np.uint16) * 256) + (243 - 1)) // 243).astype(np.uint8)
154153

155154
return q.reshape(__quantize_q1_3_shape_change(shape))
156155

tests/test-quantize-fns.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,13 @@
1515

1616
constexpr float MAX_QUANTIZATION_REFERENCE_ERROR = 0.0001f;
1717
constexpr float MAX_QUANTIZATION_TOTAL_ERROR = 0.002f;
18+
constexpr float MAX_QUANTIZATION_TOTAL_ERROR_BITNET = 0.015625f;
1819
constexpr float MAX_QUANTIZATION_TOTAL_ERROR_2BITS = 0.0075f;
1920
constexpr float MAX_QUANTIZATION_TOTAL_ERROR_3BITS = 0.0040f;
2021
constexpr float MAX_QUANTIZATION_TOTAL_ERROR_3BITS_XXS = 0.0050f;
2122
constexpr float MAX_DOT_PRODUCT_ERROR = 0.02f;
2223
constexpr float MAX_DOT_PRODUCT_ERROR_LOWBIT = 0.04f;
24+
constexpr float MAX_DOT_PRODUCT_ERROR_BITNET = 0.5f;
2325

2426
static const char* RESULT_STR[] = {"ok", "FAILED"};
2527

@@ -144,6 +146,8 @@ int main(int argc, char * argv[]) {
144146
if (qfns.from_float && qfns.to_float) {
145147
const float total_error = total_quantization_error(qfns, test_size, test_data.data());
146148
const float max_quantization_error =
149+
type == GGML_TYPE_Q1_3 ? MAX_QUANTIZATION_TOTAL_ERROR_BITNET :
150+
type == GGML_TYPE_Q2_2 ? MAX_QUANTIZATION_TOTAL_ERROR_BITNET :
147151
type == GGML_TYPE_Q2_K ? MAX_QUANTIZATION_TOTAL_ERROR_2BITS :
148152
type == GGML_TYPE_IQ2_S ? MAX_QUANTIZATION_TOTAL_ERROR_2BITS :
149153
type == GGML_TYPE_Q3_K ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS :
@@ -166,6 +170,8 @@ int main(int argc, char * argv[]) {
166170
const float max_allowed_error = type == GGML_TYPE_Q2_K || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ2_XXS ||
167171
type == GGML_TYPE_IQ3_XXS || type == GGML_TYPE_IQ3_S || type == GGML_TYPE_IQ2_S
168172
? MAX_DOT_PRODUCT_ERROR_LOWBIT
173+
: type == GGML_TYPE_Q2_2 || type == GGML_TYPE_Q1_3
174+
? MAX_DOT_PRODUCT_ERROR_BITNET
169175
: MAX_DOT_PRODUCT_ERROR;
170176
failed = !(vec_dot_error < max_allowed_error);
171177
num_failed += failed;

0 commit comments

Comments
 (0)