Skip to content

Commit 3770f4f

Browse files
committed
Generalize quantize_fns for simpler FP16 handling
1 parent b8c8dda commit 3770f4f

File tree

8 files changed

+173
-611
lines changed

8 files changed

+173
-611
lines changed

examples/quantize-stats/quantize-stats.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ void test_roundtrip_on_chunk(
147147
const ggml_tensor * layer,
148148
int64_t offset,
149149
int64_t chunk_size,
150-
const quantize_fns_t & qfns,
150+
const ggml_type_handling_t & qfns,
151151
bool use_reference,
152152
float * input_scratch,
153153
char * quantized_scratch,
@@ -163,11 +163,11 @@ void test_roundtrip_on_chunk(
163163
}
164164

165165
if (use_reference) {
166-
qfns.quantize_row_q_reference(input_scratch, quantized_scratch, chunk_size);
166+
qfns.from_float_reference(input_scratch, quantized_scratch, chunk_size);
167167
} else {
168-
qfns.quantize_row_q(input_scratch, quantized_scratch, chunk_size);
168+
qfns.from_float(input_scratch, quantized_scratch, chunk_size);
169169
}
170-
qfns.dequantize_row_q(quantized_scratch, output_scratch, chunk_size);
170+
qfns.to_float(quantized_scratch, output_scratch, chunk_size);
171171

172172
update_error_stats(chunk_size, input_scratch, output_scratch, stats);
173173
}
@@ -177,7 +177,7 @@ void test_roundtrip_on_chunk(
177177
void test_roundtrip_on_layer(
178178
std::string & name,
179179
bool print_layer_stats,
180-
const quantize_fns_t & qfns,
180+
const ggml_type_handling_t & qfns,
181181
bool use_reference,
182182
const ggml_tensor * layer,
183183
std::vector<float> & input_scratch,
@@ -388,8 +388,8 @@ int main(int argc, char ** argv) {
388388
if (!params.include_types.empty() && std::find(params.include_types.begin(), params.include_types.end(), i) == params.include_types.end()) {
389389
continue;
390390
}
391-
quantize_fns_t qfns = ggml_internal_get_quantize_fn(i);
392-
if (qfns.quantize_row_q && qfns.dequantize_row_q) {
391+
ggml_type_handling_t qfns = ggml_internal_get_type_handling(type);
392+
if (qfns.from_float && qfns.to_float) {
393393
if (params.verbose) {
394394
printf("testing %s ...\n", ggml_type_name(type));
395395
}

ggml.c

Lines changed: 109 additions & 544 deletions
Large diffs are not rendered by default.

ggml.h

Lines changed: 13 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -224,8 +224,8 @@ extern "C" {
224224
GGML_API float ggml_fp16_to_fp32(ggml_fp16_t x);
225225
GGML_API ggml_fp16_t ggml_fp32_to_fp16(float x);
226226

227-
GGML_API void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, size_t n);
228-
GGML_API void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, size_t n);
227+
GGML_API void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, int n);
228+
GGML_API void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int n);
229229

230230
struct ggml_object;
231231
struct ggml_context;
@@ -1487,26 +1487,19 @@ extern "C" {
14871487
// Internal types and functions exposed for tests and benchmarks
14881488
//
14891489

1490-
#ifdef __cplusplus
1491-
// restrict not standard in C++
1492-
#define GGML_RESTRICT
1493-
#else
1494-
#define GGML_RESTRICT restrict
1495-
#endif
1496-
typedef void (*dequantize_row_q_t)(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
1497-
typedef void (*quantize_row_q_t) (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
1498-
typedef void (*vec_dot_q_t) (const int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT x, const void * GGML_RESTRICT y);
1490+
typedef void (*ggml_to_float_t)(const void * x, float * y, int k);
1491+
typedef void (*ggml_from_float_t)(const float * x, void * y, int k);
1492+
typedef void (*ggml_vec_dot_t)(const int n, float * s, const void * x, const void * y);
14991493

15001494
typedef struct {
1501-
dequantize_row_q_t dequantize_row_q;
1502-
quantize_row_q_t quantize_row_q;
1503-
quantize_row_q_t quantize_row_q_reference;
1504-
quantize_row_q_t quantize_row_q_dot;
1505-
vec_dot_q_t vec_dot_q;
1506-
enum ggml_type vec_dot_type;
1507-
} quantize_fns_t;
1508-
1509-
quantize_fns_t ggml_internal_get_quantize_fn(size_t i);
1495+
ggml_to_float_t to_float;
1496+
ggml_from_float_t from_float;
1497+
ggml_from_float_t from_float_reference;
1498+
ggml_vec_dot_t vec_dot;
1499+
enum ggml_type vec_dot_type;
1500+
} ggml_type_handling_t;
1501+
1502+
ggml_type_handling_t ggml_internal_get_type_handling(enum ggml_type i);
15101503

15111504
#ifdef __cplusplus
15121505
}

llama.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2214,10 +2214,10 @@ static void llama_convert_tensor_internal(const llama_load_tensor & tensor, llam
22142214
}
22152215
float * f32_output = (float *) output.addr;
22162216

2217-
quantize_fns_t qtype;
2217+
ggml_type_handling_t qtype;
22182218
if (ggml_is_quantized(tensor.type)) {
2219-
qtype = ggml_internal_get_quantize_fn(tensor.type);
2220-
if (qtype.dequantize_row_q == NULL) {
2219+
qtype = ggml_internal_get_type_handling(tensor.type);
2220+
if (qtype.to_float == NULL) {
22212221
throw std::runtime_error(format("type %s unsupported for integer quantization: no dequantization available", ggml_type_name(tensor.type)));
22222222
}
22232223
} else if (tensor.type != GGML_TYPE_F16) {
@@ -2228,7 +2228,7 @@ static void llama_convert_tensor_internal(const llama_load_tensor & tensor, llam
22282228
if (tensor.type == GGML_TYPE_F16) {
22292229
ggml_fp16_to_fp32_row((ggml_fp16_t *)tensor.data, f32_output, nelements);
22302230
} else if (ggml_is_quantized(tensor.type)) {
2231-
qtype.dequantize_row_q(tensor.data, f32_output, nelements);
2231+
qtype.to_float(tensor.data, f32_output, nelements);
22322232
} else {
22332233
LLAMA_ASSERT(false); // unreachable
22342234
}
@@ -2253,7 +2253,7 @@ static void llama_convert_tensor_internal(const llama_load_tensor & tensor, llam
22532253
if (typ == GGML_TYPE_F16) {
22542254
ggml_fp16_to_fp32_row((ggml_fp16_t *)inbuf, outbuf, nels);
22552255
} else {
2256-
qtype.dequantize_row_q(inbuf, outbuf, nels);
2256+
qtype.to_float(inbuf, outbuf, nels);
22572257
}
22582258
};
22592259
workers.push_back(std::thread(compute, tensor.type, tensor.data + in_buff_offs, f32_output + out_buff_offs, thr_elems));

pocs/vdot/q8dot.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ int main(int argc, char** argv) {
136136

137137
auto ggml_type = type == 0 ? GGML_TYPE_Q4_0 : GGML_TYPE_Q4_1;
138138

139-
auto funcs = ggml_internal_get_quantize_fn(ggml_type);
139+
auto funcs = ggml_internal_get_type_handling(ggml_type);
140140

141141
Stat simple, ggml;
142142

@@ -156,8 +156,8 @@ int main(int argc, char** argv) {
156156

157157
t1 = std::chrono::high_resolution_clock::now();
158158
float fs;
159-
if (type == 0) funcs.vec_dot_q(kVecSize * QK4_1, &fs, x40.data(), y.data());
160-
else funcs.vec_dot_q(kVecSize * QK4_1, &fs, x41.data(), y.data());
159+
if (type == 0) funcs.vec_dot(kVecSize * QK4_1, &fs, x40.data(), y.data());
160+
else funcs.vec_dot(kVecSize * QK4_1, &fs, x41.data(), y.data());
161161
t2 = std::chrono::high_resolution_clock::now();
162162
t = 1e-3*std::chrono::duration_cast<std::chrono::nanoseconds>(t2-t1).count();
163163
if (iloop > 3) ggml.addResult(fs, t);

pocs/vdot/vdot.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ int main(int argc, char** argv) {
235235
int n4 = useQ4_1 ? kVecSize / QK4_1 : kVecSize / QK4_0; n4 = 64*((n4 + 63)/64);
236236
int n8 = kVecSize / QK8_0; n8 = 64*((n8 + 63)/64);
237237

238-
auto funcs = useQ4_1 ? ggml_internal_get_quantize_fn(GGML_TYPE_Q4_1) : ggml_internal_get_quantize_fn(GGML_TYPE_Q4_0);
238+
auto funcs = useQ4_1 ? ggml_internal_get_type_handling(GGML_TYPE_Q4_1) : ggml_internal_get_type_handling(GGML_TYPE_Q4_0);
239239

240240
std::vector<block_q4_0> q40;
241241
std::vector<block_q4_1> q41;
@@ -261,9 +261,9 @@ int main(int argc, char** argv) {
261261
// Note, we do not include this in the timing as in practical application
262262
// we already have the quantized model weights.
263263
if (useQ4_1) {
264-
funcs.quantize_row_q(x1.data(), q41.data(), kVecSize);
264+
funcs.from_float(x1.data(), q41.data(), kVecSize);
265265
} else {
266-
funcs.quantize_row_q(x1.data(), q40.data(), kVecSize);
266+
funcs.from_float(x1.data(), q40.data(), kVecSize);
267267
}
268268

269269
// Now measure time the dot product needs using the "scalar" version above
@@ -282,9 +282,10 @@ int main(int argc, char** argv) {
282282
dot_q4_q8(kVecSize, &result, q40.data(), q8.data());
283283
}
284284
else {
285-
funcs.quantize_row_q_dot(y1.data(), q8.data(), kVecSize);
286-
if (useQ4_1) funcs.vec_dot_q(kVecSize, &result, q41.data(), q8.data());
287-
else funcs.vec_dot_q(kVecSize, &result, q40.data(), q8.data());
285+
auto vdot = ggml_internal_get_type_handling(funcs.vec_dot_type);
286+
vdot.from_float(y1.data(), q8.data(), kVecSize);
287+
if (useQ4_1) funcs.vec_dot(kVecSize, &result, q41.data(), q8.data());
288+
else funcs.vec_dot(kVecSize, &result, q40.data(), q8.data());
288289
}
289290
sumq += result;
290291
t2 = std::chrono::high_resolution_clock::now();

tests/test-quantize-fns.cpp

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -40,26 +40,26 @@ float array_rmse(const float * a1, const float * a2, size_t n) {
4040
}
4141

4242
// Total quantization error on test data
43-
float total_quantization_error(quantize_fns_t & qfns, size_t test_size, const float * test_data) {
43+
float total_quantization_error(ggml_type_handling_t & qfns, size_t test_size, const float * test_data) {
4444
std::vector<uint8_t> tmp_q(2*test_size);
4545
std::vector<float> tmp_out(test_size);
4646

47-
qfns.quantize_row_q(test_data, tmp_q.data(), test_size);
48-
qfns.dequantize_row_q(tmp_q.data(), tmp_out.data(), test_size);
47+
qfns.from_float(test_data, tmp_q.data(), test_size);
48+
qfns.to_float(tmp_q.data(), tmp_out.data(), test_size);
4949
return array_rmse(test_data, tmp_out.data(), test_size);
5050
}
5151

5252
// Total quantization error on test data
53-
float reference_quantization_error(quantize_fns_t & qfns, size_t test_size, const float * test_data) {
53+
float reference_quantization_error(ggml_type_handling_t & qfns, size_t test_size, const float * test_data) {
5454
std::vector<uint8_t> tmp_q(2*test_size);
5555
std::vector<float> tmp_out(test_size);
5656
std::vector<float> tmp_out_ref(test_size);
5757

58-
qfns.quantize_row_q(test_data, tmp_q.data(), test_size);
59-
qfns.dequantize_row_q(tmp_q.data(), tmp_out.data(), test_size);
58+
qfns.from_float(test_data, tmp_q.data(), test_size);
59+
qfns.to_float(tmp_q.data(), tmp_out.data(), test_size);
6060

61-
qfns.quantize_row_q_reference(test_data, tmp_q.data(), test_size);
62-
qfns.dequantize_row_q(tmp_q.data(), tmp_out_ref.data(), test_size);
61+
qfns.from_float_reference(test_data, tmp_q.data(), test_size);
62+
qfns.to_float(tmp_q.data(), tmp_out_ref.data(), test_size);
6363

6464
return array_rmse(tmp_out.data(), tmp_out_ref.data(), test_size);
6565
}
@@ -73,15 +73,17 @@ float dot_product(const float * a1, const float * a2, size_t test_size) {
7373
}
7474

7575
// Total dot product error
76-
float dot_product_error(quantize_fns_t & qfns, size_t test_size, const float * test_data1, const float *test_data2) {
76+
float dot_product_error(ggml_type_handling_t & qfns, size_t test_size, const float * test_data1, const float *test_data2) {
7777
std::vector<uint8_t> tmp_q1(2*test_size);
7878
std::vector<uint8_t> tmp_q2(2*test_size);
7979

80-
qfns.quantize_row_q (test_data1, tmp_q1.data(), test_size);
81-
qfns.quantize_row_q_dot(test_data2, tmp_q2.data(), test_size);
80+
auto vdot = ggml_internal_get_type_handling(qfns.vec_dot_type);
81+
82+
qfns.from_float(test_data1, tmp_q1.data(), test_size);
83+
vdot.from_float(test_data2, tmp_q2.data(), test_size);
8284

8385
float result = INFINITY;
84-
qfns.vec_dot_q(test_size, &result, tmp_q1.data(), tmp_q2.data());
86+
qfns.vec_dot(test_size, &result, tmp_q1.data(), tmp_q2.data());
8587

8688
const float dot_ref = dot_product(test_data1, test_data2, test_size);
8789

@@ -123,9 +125,9 @@ int main(int argc, char * argv[]) {
123125

124126
for (int i = 0; i < GGML_TYPE_COUNT; i++) {
125127
ggml_type type = (ggml_type) i;
126-
quantize_fns_t qfns = ggml_internal_get_quantize_fn(i);
128+
ggml_type_handling_t qfns = ggml_internal_get_type_handling(type);
127129

128-
if (qfns.quantize_row_q && qfns.dequantize_row_q) {
130+
if (qfns.from_float && qfns.to_float) {
129131
const float total_error = total_quantization_error(qfns, test_size, test_data.data());
130132
const float max_quantization_error =
131133
type == GGML_TYPE_Q2_K ? MAX_QUANTIZATION_TOTAL_ERROR_2BITS :

tests/test-quantize-perf.cpp

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -123,9 +123,9 @@ void usage(char * argv[]) {
123123
printf(" --type TYPE set test type as");
124124
for (int i = 0; i < GGML_TYPE_COUNT; i++) {
125125
ggml_type type = (ggml_type) i;
126-
quantize_fns_t qfns = ggml_internal_get_quantize_fn(type);
126+
ggml_type_handling_t qfns = ggml_internal_get_type_handling(type);
127127
if (ggml_type_name(type) != NULL) {
128-
if (qfns.quantize_row_q && qfns.dequantize_row_q) {
128+
if (qfns.from_float && qfns.to_float) {
129129
printf(" %s", ggml_type_name(type));
130130
}
131131
}
@@ -271,20 +271,20 @@ int main(int argc, char * argv[]) {
271271

272272
for (int i = 0; i < GGML_TYPE_COUNT; i++) {
273273
ggml_type type = (ggml_type) i;
274-
quantize_fns_t qfns = ggml_internal_get_quantize_fn(i);
274+
ggml_type_handling_t qfns = ggml_internal_get_type_handling(type);
275275
if (!params.include_types.empty() && ggml_type_name(type) && std::find(params.include_types.begin(), params.include_types.end(), ggml_type_name(type)) == params.include_types.end()) {
276276
continue;
277277
}
278278

279-
if (qfns.quantize_row_q && qfns.dequantize_row_q) {
279+
if (qfns.from_float && qfns.to_float) {
280280
printf("%s\n", ggml_type_name(type));
281281

282282
if (params.op_quantize_row_q_reference) {
283283
printf(" quantize_row_q_reference\n");
284284
for (size_t size : params.test_sizes) {
285285
printf(" %zu values (%.2f MB)\n", size, 4*size/(float)(1024*1024));
286286
auto quantize_fn = [&](void) -> float {
287-
qfns.quantize_row_q_reference(test_data1, test_q1, size);
287+
qfns.from_float_reference(test_data1, test_q1, size);
288288
return test_q1[0];
289289
};
290290
size_t quantized_size = size / ggml_blck_size(type) * ggml_type_size(type);
@@ -298,7 +298,7 @@ int main(int argc, char * argv[]) {
298298
for (size_t size : params.test_sizes) {
299299
printf(" %zu values (%.2f MB)\n", size, 4*size/(float)(1024*1024));
300300
auto quantize_fn = [&](void) -> float {
301-
qfns.quantize_row_q(test_data1, test_q1, size);
301+
qfns.from_float(test_data1, test_q1, size);
302302
return test_q1[0];
303303
};
304304
size_t quantized_size = size / ggml_blck_size(type) * ggml_type_size(type);
@@ -309,11 +309,11 @@ int main(int argc, char * argv[]) {
309309

310310
if (params.op_dequantize_row_q) {
311311
printf(" dequantize_row_q\n");
312-
qfns.quantize_row_q(test_data1, test_q1, largest);
312+
qfns.from_float(test_data1, test_q1, largest);
313313
for (size_t size : params.test_sizes) {
314314
printf(" %zu values (%.2f MB)\n", size, 4*size/(float)(1024*1024));
315315
auto quantize_fn = [&](void) -> float {
316-
qfns.dequantize_row_q(test_q1, test_out, size);
316+
qfns.to_float(test_q1, test_out, size);
317317
return test_out[0];
318318
};
319319
size_t quantized_size = size / ggml_blck_size(type) * ggml_type_size(type);
@@ -327,7 +327,8 @@ int main(int argc, char * argv[]) {
327327
for (size_t size : params.test_sizes) {
328328
printf(" %zu values (%.2f MB)\n", size, 4*size/(float)(1024*1024));
329329
auto quantize_fn = [&](void) -> float {
330-
qfns.quantize_row_q_dot(test_data1, test_q1, size);
330+
auto vdot = ggml_internal_get_type_handling(qfns.vec_dot_type);
331+
vdot.from_float(test_data1, test_q1, size);
331332
return test_q1[0];
332333
};
333334
size_t quantized_size = size / ggml_blck_size(type) * ggml_type_size(type);
@@ -338,13 +339,13 @@ int main(int argc, char * argv[]) {
338339

339340
if (params.op_vec_dot_q) {
340341
printf(" vec_dot_q\n");
341-
qfns.quantize_row_q(test_data1, test_q1, largest);
342-
qfns.quantize_row_q(test_data2, test_q2, largest);
342+
qfns.from_float(test_data1, test_q1, largest);
343+
qfns.from_float(test_data2, test_q2, largest);
343344
for (size_t size : params.test_sizes) {
344345
printf(" %zu values (%.2f MB)\n", size, 4*size/(float)(1024*1024));
345346
auto quantize_fn = [&](void ) {
346347
float result;
347-
qfns.vec_dot_q(size, &result, test_q1, test_q2);
348+
qfns.vec_dot(size, &result, test_q1, test_q2);
348349
return result;
349350
};
350351
size_t quantized_size = size / ggml_blck_size(type) * ggml_type_size(type);

0 commit comments

Comments (0)