|
1 | 1 | // quantized matrix multiplication
|
2 | 2 |
|
| 3 | +#include "ggml.h" |
| 4 | + |
3 | 5 | #include <float.h>
|
4 | 6 | #include <stdint.h>
|
5 | 7 | #include <stdio.h>
|
@@ -59,6 +61,8 @@ void mul_mat_vec_f32_0(
|
59 | 61 | void quantize(const float * src, void * dst, int n, int k) {
|
60 | 62 | char * p0 = dst;
|
61 | 63 |
|
| 64 | + gq_t pp[QB]; |
| 65 | + |
62 | 66 | for (int j = 0; j < n; j++) {
|
63 | 67 | for (int i = 0; i < k/QK; i++) {
|
64 | 68 | float min = FLT_MAX;
|
@@ -105,7 +109,7 @@ void quantize(const float * src, void * dst, int n, int k) {
|
105 | 109 | //printf("min/max/d/id: %f %f %f %f\n", min, max, d, id);
|
106 | 110 |
|
107 | 111 | for (int s = 0; s < QK/gq_t_bits; ++s) {
|
108 |
| - gq_t pp[QB] = {0}; |
| 112 | + memset(pp, 0, sizeof(pp)); |
109 | 113 |
|
110 | 114 | for (int l = 0; l < gq_t_bits; l++) {
|
111 | 115 | const float v = src[j*k + i*QK + s*gq_t_bits + l];
|
@@ -209,7 +213,7 @@ int main(int argc, const char ** argv) {
|
209 | 213 | void * src0_gq = calloc(1, (2*sizeof(float) + (QK/gq_t_bits)*QB*sizeof(gq_t))*(K/QK)*M);
|
210 | 214 | void * src1_gq = calloc(1, (2*sizeof(float) + (QK/gq_t_bits)*QB*sizeof(gq_t))*(K/QK)*N);
|
211 | 215 |
|
212 |
| - const size_t sizef16 = sizeof(__fp16)*M*K + sizeof(__fp16)*N*K; |
| 216 | + const size_t sizef16 = sizeof(ggml_fp16_t)*M*K + sizeof(ggml_fp16_t)*N*K; |
213 | 217 | const size_t sizegq = (2*sizeof(float) + (QK/gq_t_bits)*QB*sizeof(gq_t))*(K/QK)*M +
|
214 | 218 | (2*sizeof(float) + (QK/gq_t_bits)*QB*sizeof(gq_t))*(K/QK)*N;
|
215 | 219 |
|
@@ -256,7 +260,7 @@ int main(int argc, const char ** argv) {
|
256 | 260 | const clock_t end = clock();
|
257 | 261 | const uint64_t end_us = get_time_us();
|
258 | 262 | printf("%s: elapsed ticks: %ld\n", __func__, end - start);
|
259 |
| - printf("%s: elapsed us: %llu / %f ms\n", __func__, end_us - start_us, (end_us - start_us) / 1000.0 / nIter); |
| 263 | + printf("%s: elapsed us: %d / %f ms\n", __func__, (int)(end_us - start_us), (end_us - start_us) / 1000.0 / nIter); |
260 | 264 | }
|
261 | 265 |
|
262 | 266 | printf("%f\n", sum);
|
|
0 commit comments