Skip to content

Commit a5c893a

Browse files
committed
test-quantize: fix for q8_0 intermediates
1 parent ba6f75e commit a5c893a

File tree

2 files changed

+26
-10
lines changed

2 files changed

+26
-10
lines changed

tests/test-quantize-fns.cpp

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,6 @@
1212

1313
const float MAX_QUANTIZATION_REFERENCE_ERROR = 0.0001;
1414
const float MAX_QUANTIZATION_TOTAL_ERROR = 0.002;
15-
// TODO: check why q4_1 is high
1615
const float MAX_DOT_PRODUCT_ERROR = 0.02;
1716

1817
const char* RESULT_STR[] = {"ok", "FAILED"};
@@ -71,10 +70,10 @@ float dot_product(const float * a1, const float * a2, size_t test_size) {
7170
// Total dot product error
7271
float dot_product_error(quantize_fns_t & qfns, size_t test_size, const float * test_data1, const float *test_data2) {
7372
std::vector<uint8_t> tmp_q1(test_size);
74-
std::vector<uint8_t> tmp_q2(test_size);
73+
std::vector<uint8_t> tmp_q2(test_size*2);
7574

7675
qfns.quantize_row_q(test_data1, tmp_q1.data(), test_size);
77-
qfns.quantize_row_q(test_data2, tmp_q2.data(), test_size);
76+
qfns.quantize_row_q_dot(test_data2, tmp_q2.data(), test_size);
7877

7978
float result = INFINITY;
8079
qfns.vec_dot_q(test_size, &result, tmp_q1.data(), tmp_q2.data());
@@ -121,7 +120,7 @@ int main(int argc, char * argv[]) {
121120
ggml_type type = (ggml_type) i;
122121
quantize_fns_t qfns = ggml_internal_get_quantize_fn(i);
123122

124-
if (qfns.quantize_row_q) {
123+
if (qfns.quantize_row_q && qfns.dequantize_row_q) {
125124
const float total_error = total_quantization_error(qfns, test_size, test_data.data());
126125
failed = !(total_error < MAX_QUANTIZATION_TOTAL_ERROR);
127126
num_failed += failed;

tests/test-quantize-perf.cpp

Lines changed: 23 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -30,6 +30,7 @@ struct quantize_perf_params {
3030
bool op_quantize_row_q_reference = false;
3131
bool op_quantize_row_q = false;
3232
bool op_dequantize_row_q = false;
33+
bool op_quantize_row_q_dot = false;
3334
bool op_vec_dot_q = false;
3435
};
3536

@@ -147,6 +148,8 @@ int main(int argc, char * argv[]) {
147148
params.op_quantize_row_q = true;
148149
} else if (op == "dequantize_row_q") {
149150
params.op_dequantize_row_q = true;
151+
} else if (op == "quantize_row_q_dot") {
152+
params.op_quantize_row_q_dot = true;
150153
} else if (op == "vec_dot_q") {
151154
params.op_vec_dot_q = true;
152155
} else {
@@ -184,8 +187,8 @@ int main(int argc, char * argv[]) {
184187
if (params.test_sizes.empty()) {
185188
params.test_sizes.push_back(L1_SIZE);
186189
}
187-
if (!(params.op_quantize_row_q_reference || params.op_quantize_row_q || params.op_dequantize_row_q || params.op_vec_dot_q)) {
188-
params.op_quantize_row_q_reference = params.op_quantize_row_q = params.op_dequantize_row_q = params.op_vec_dot_q = true;
190+
if (!(params.op_quantize_row_q_reference || params.op_quantize_row_q || params.op_dequantize_row_q || params.op_quantize_row_q_dot || params.op_vec_dot_q)) {
191+
params.op_quantize_row_q_reference = params.op_quantize_row_q = params.op_dequantize_row_q = params.op_quantize_row_q_dot = params.op_vec_dot_q = true;
189192
}
190193

191194
std::sort(params.test_sizes.begin(), params.test_sizes.end());
@@ -225,7 +228,7 @@ int main(int argc, char * argv[]) {
225228
if (qfns.quantize_row_q) {
226229
printf("%s\n", ggml_type_name(type));
227230

228-
if (params.op_quantize_row_q_reference) {
231+
if (params.op_quantize_row_q_reference && qfns.quantize_row_q_reference) {
229232
printf(" quantize_row_q_reference\n");
230233
for (size_t size : params.test_sizes) {
231234
printf(" %zu values (%.2f MB)\n", size, 4*size/(float)(1024*1024));
@@ -239,7 +242,7 @@ int main(int argc, char * argv[]) {
239242
printf("\n");
240243
}
241244

242-
if (params.op_quantize_row_q) {
245+
if (params.op_quantize_row_q && qfns.quantize_row_q) {
243246
printf(" quantize_row_q\n");
244247
for (size_t size : params.test_sizes) {
245248
printf(" %zu values (%.2f MB)\n", size, 4*size/(float)(1024*1024));
@@ -253,7 +256,7 @@ int main(int argc, char * argv[]) {
253256
printf("\n");
254257
}
255258

256-
if (params.op_dequantize_row_q) {
259+
if (params.op_dequantize_row_q && qfns.dequantize_row_q) {
257260
printf(" dequantize_row_q\n");
258261
qfns.quantize_row_q(test_data1, test_q1, largest);
259262
for (size_t size : params.test_sizes) {
@@ -268,7 +271,21 @@ int main(int argc, char * argv[]) {
268271
printf("\n");
269272
}
270273

271-
if (params.op_vec_dot_q) {
274+
if (params.op_quantize_row_q_dot && qfns.quantize_row_q_dot) {
275+
printf(" quantize_row_q_dot\n");
276+
for (size_t size : params.test_sizes) {
277+
printf(" %zu values (%.2f MB)\n", size, 4*size/(float)(1024*1024));
278+
auto quantize_fn = [&](void ) {
279+
qfns.quantize_row_q_dot(test_data1, test_q1, size);
280+
return test_q1[0];
281+
};
282+
size_t quantized_size = size / ggml_blck_size(type) * ggml_type_size(type);
283+
benchmark_function(size, quantized_size, quantize_fn);
284+
}
285+
printf("\n");
286+
}
287+
288+
if (params.op_vec_dot_q && qfns.vec_dot_q) {
272289
printf(" vec_dot_q\n");
273290
qfns.quantize_row_q(test_data1, test_q1, largest);
274291
qfns.quantize_row_q(test_data2, test_q2, largest);

0 commit comments

Comments
 (0)