@@ -30,6 +30,7 @@ struct quantize_perf_params {
30
30
bool op_quantize_row_q_reference = false ;
31
31
bool op_quantize_row_q = false ;
32
32
bool op_dequantize_row_q = false ;
33
+ bool op_quantize_row_q_dot = false ;
33
34
bool op_vec_dot_q = false ;
34
35
};
35
36
@@ -147,6 +148,8 @@ int main(int argc, char * argv[]) {
147
148
params.op_quantize_row_q = true ;
148
149
} else if (op == " dequantize_row_q" ) {
149
150
params.op_dequantize_row_q = true ;
151
+ } else if (op == " quantize_row_q_dot" ) {
152
+ params.op_quantize_row_q_dot = true ;
150
153
} else if (op == " vec_dot_q" ) {
151
154
params.op_vec_dot_q = true ;
152
155
} else {
@@ -184,8 +187,8 @@ int main(int argc, char * argv[]) {
184
187
if (params.test_sizes .empty ()) {
185
188
params.test_sizes .push_back (L1_SIZE);
186
189
}
187
- if (!(params.op_quantize_row_q_reference || params.op_quantize_row_q || params.op_dequantize_row_q || params.op_vec_dot_q )) {
188
- params.op_quantize_row_q_reference = params.op_quantize_row_q = params.op_dequantize_row_q = params.op_vec_dot_q = true ;
190
+ if (!(params.op_quantize_row_q_reference || params.op_quantize_row_q || params.op_dequantize_row_q || params.op_quantize_row_q_dot || params. op_vec_dot_q )) {
191
+ params.op_quantize_row_q_reference = params.op_quantize_row_q = params.op_dequantize_row_q = params.op_quantize_row_q_dot = params. op_vec_dot_q = true ;
189
192
}
190
193
191
194
std::sort (params.test_sizes .begin (), params.test_sizes .end ());
@@ -225,7 +228,7 @@ int main(int argc, char * argv[]) {
225
228
if (qfns.quantize_row_q ) {
226
229
printf (" %s\n " , ggml_type_name (type));
227
230
228
- if (params.op_quantize_row_q_reference ) {
231
+ if (params.op_quantize_row_q_reference && qfns. quantize_row_q_reference ) {
229
232
printf (" quantize_row_q_reference\n " );
230
233
for (size_t size : params.test_sizes ) {
231
234
printf (" %zu values (%.2f MB)\n " , size, 4 *size/(float )(1024 *1024 ));
@@ -239,7 +242,7 @@ int main(int argc, char * argv[]) {
239
242
printf (" \n " );
240
243
}
241
244
242
- if (params.op_quantize_row_q ) {
245
+ if (params.op_quantize_row_q && qfns. quantize_row_q ) {
243
246
printf (" quantize_row_q\n " );
244
247
for (size_t size : params.test_sizes ) {
245
248
printf (" %zu values (%.2f MB)\n " , size, 4 *size/(float )(1024 *1024 ));
@@ -253,7 +256,7 @@ int main(int argc, char * argv[]) {
253
256
printf (" \n " );
254
257
}
255
258
256
- if (params.op_dequantize_row_q ) {
259
+ if (params.op_dequantize_row_q && qfns. dequantize_row_q ) {
257
260
printf (" dequantize_row_q\n " );
258
261
qfns.quantize_row_q (test_data1, test_q1, largest);
259
262
for (size_t size : params.test_sizes ) {
@@ -268,7 +271,21 @@ int main(int argc, char * argv[]) {
268
271
printf (" \n " );
269
272
}
270
273
271
- if (params.op_vec_dot_q ) {
274
+ if (params.op_quantize_row_q_dot && qfns.quantize_row_q_dot ) {
275
+ printf (" quantize_row_q_dot\n " );
276
+ for (size_t size : params.test_sizes ) {
277
+ printf (" %zu values (%.2f MB)\n " , size, 4 *size/(float )(1024 *1024 ));
278
+ auto quantize_fn = [&](void ) {
279
+ qfns.quantize_row_q_dot (test_data1, test_q1, size);
280
+ return test_q1[0 ];
281
+ };
282
+ size_t quantized_size = size / ggml_blck_size (type) * ggml_type_size (type);
283
+ benchmark_function (size, quantized_size, quantize_fn);
284
+ }
285
+ printf (" \n " );
286
+ }
287
+
288
+ if (params.op_vec_dot_q && qfns.vec_dot_q ) {
272
289
printf (" vec_dot_q\n " );
273
290
qfns.quantize_row_q (test_data1, test_q1, largest);
274
291
qfns.quantize_row_q (test_data2, test_q2, largest);
0 commit comments