Skip to content

Commit deb0c48

Browse files
committed
tests : wip quantized matrix multiplication method 2
1 parent d677c7f commit deb0c48

File tree

2 files changed

+177
-18
lines changed

2 files changed

+177
-18
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ endif()
4747

4848
#set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -ffast-math")
4949
#set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=native")
50+
#set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mcpu=native")
5051

5152
# dependencies
5253

tests/test-mul-mat2.c

Lines changed: 176 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ uint64_t get_time_us() {
4242
// naive implementation
4343
//
4444

45-
void mul_mat_vec_f32_0(
45+
void mul_mat_vec_f32_naive(
4646
const float * restrict src0, // M x K
4747
const float * restrict src1, // N x K (transposed)
4848
float * dst,
@@ -58,7 +58,11 @@ void mul_mat_vec_f32_0(
5858
}
5959
}
6060

61-
void quantize(const float * src, void * dst, int n, int k) {
61+
//
62+
// method 1
63+
//
64+
65+
void quantize_1(const float * src, void * dst, int n, int k) {
6266
char * p0 = dst;
6367

6468
gq_t pp[QB];
@@ -128,7 +132,7 @@ void quantize(const float * src, void * dst, int n, int k) {
128132
}
129133
}
130134

131-
void mul_mat_vec_gq_0(
135+
void mul_mat_vec_gq_1(
132136
const void * src0,
133137
const void * src1,
134138
float * dst,
@@ -138,6 +142,12 @@ void mul_mat_vec_gq_0(
138142
const char * restrict p0 = src0;
139143
const char * restrict p1 = src1;
140144

145+
float s0[QB + 1];
146+
float s1[QB + 1];
147+
148+
gq_t m0[QB + 1];
149+
gq_t m1[QB + 1];
150+
141151
for (int ir0 = 0; ir0 < m; ir0++) {
142152
for (int ir1 = 0; ir1 < n; ir1++) {
143153
float sumf = 0.0;
@@ -159,9 +169,6 @@ void mul_mat_vec_gq_0(
159169
#if 1
160170
// >>> General case for any QB
161171

162-
float s0[QB + 1];
163-
float s1[QB + 1];
164-
165172
s0[0] = min0;
166173
s1[0] = min1;
167174

@@ -170,8 +177,146 @@ void mul_mat_vec_gq_0(
170177
s1[b + 1] = d1*(1 << b);
171178
}
172179

173-
gq_t m0[QB + 1];
174-
gq_t m1[QB + 1];
180+
m0[0] = -1LL;
181+
m1[0] = -1LL;
182+
183+
for (int s = 0; s < QK/gq_t_bits; ++s) {
184+
for (int b = 0; b < QB; b++) {
185+
memcpy(&m0[b + 1], pp0, sizeof(gq_t)); pp0 += sizeof(gq_t);
186+
memcpy(&m1[b + 1], pp1, sizeof(gq_t)); pp1 += sizeof(gq_t);
187+
}
188+
189+
for (int q0 = 0; q0 < QB + 1; q0++) {
190+
for (int q1 = 0; q1 < QB + 1; q1++) {
191+
sumf += s0[q0]*s1[q1]*__builtin_popcountll(m0[q0] & m1[q1]);
192+
}
193+
}
194+
}
195+
#else
196+
#endif
197+
}
198+
199+
dst[ir0*n + ir1] = sumf;
200+
}
201+
}
202+
}
203+
204+
//
205+
// method 2
206+
//
207+
208+
void quantize_2(const float * src, void * dst, int n, int k) {
209+
char * p0 = dst;
210+
211+
for (int j = 0; j < n; j++) {
212+
for (int i = 0; i < k/QK; i++) {
213+
float min = FLT_MAX;
214+
float max = -FLT_MAX;
215+
216+
// find min/max
217+
#ifdef __ARM_NEON
218+
{
219+
float32x4_t minv = vdupq_n_f32(FLT_MAX);
220+
float32x4_t maxv = vdupq_n_f32(-FLT_MAX);
221+
222+
for (int l = 0; l < QK; l += 4) {
223+
float32x4_t v = vld1q_f32(src + j*k + i*QK + l);
224+
minv = vminq_f32(minv, v);
225+
maxv = vmaxq_f32(maxv, v);
226+
}
227+
228+
float32x2_t minv32 = vpmin_f32(vget_low_f32(minv), vget_high_f32(minv));
229+
float32x2_t maxv32 = vpmax_f32(vget_low_f32(maxv), vget_high_f32(maxv));
230+
231+
min = MIN(vget_lane_f32(minv32, 0), vget_lane_f32(minv32, 1));
232+
max = MAX(vget_lane_f32(maxv32, 0), vget_lane_f32(maxv32, 1));
233+
234+
//printf("SIMD min/max: %f %f\n", min, max);
235+
}
236+
#else
237+
{
238+
for (int l = 0; l < QK; l++) {
239+
const float v = src[j*k + i*QK + l];
240+
if (v < min) min = v;
241+
if (v > max) max = v;
242+
}
243+
244+
//printf("NORM min/max: %f %f\n", min, max);
245+
}
246+
#endif
247+
248+
const float d = (max - min) / ((1 << QB) - 1);
249+
const float id = d ? 1.0/d : 0.0;
250+
251+
memcpy(p0, &min, sizeof(float)); p0 += sizeof(float);
252+
memcpy(p0, &d, sizeof(float)); p0 += sizeof(float);
253+
254+
//printf("min/max/d/id: %f %f %f %f\n", min, max, d, id);
255+
256+
for (int s = 0; s < QK/gq_t_bits; ++s) {
257+
gq_t pp[QB] = {0};
258+
259+
for (int l = 0; l < gq_t_bits; l++) {
260+
const float v = src[j*k + i*QK + s*gq_t_bits + l];
261+
const uint8_t q = (v - min)*id;
262+
263+
for (int b = 0; b < QB; b++) {
264+
pp[b] |= q & (1 << b) ? (1LL << l) : 0;
265+
}
266+
}
267+
268+
for (int b = 0; b < QB; b++) {
269+
memcpy(p0, &pp[b], sizeof(gq_t)); p0 += sizeof(gq_t);
270+
}
271+
}
272+
}
273+
}
274+
}
275+
276+
void mul_mat_vec_gq_2(
277+
const void * src0,
278+
const void * src1,
279+
float * dst,
280+
int m, int n, int k) {
281+
const int kp = k & ~(gq_t_bits - 1);
282+
283+
const char * restrict p0 = src0;
284+
const char * restrict p1 = src1;
285+
286+
float s0[QB + 1];
287+
float s1[QB + 1];
288+
289+
gq_t m0[QB + 1];
290+
gq_t m1[QB + 1];
291+
292+
for (int ir0 = 0; ir0 < m; ir0++) {
293+
for (int ir1 = 0; ir1 < n; ir1++) {
294+
float sumf = 0.0;
295+
296+
const char * restrict pp0 = p0 + ir0*((2*sizeof(float) + (QK/gq_t_bits)*QB*sizeof(gq_t))*(k/QK));
297+
const char * restrict pp1 = p1 + ir1*((2*sizeof(float) + (QK/gq_t_bits)*QB*sizeof(gq_t))*(k/QK));
298+
299+
for (int i = 0; i < kp/QK; i++) {
300+
float min0, d0;
301+
memcpy(&min0, pp0, sizeof(float)); pp0 += sizeof(float);
302+
memcpy(&d0, pp0, sizeof(float)); pp0 += sizeof(float);
303+
304+
float min1, d1;
305+
memcpy(&min1, pp1, sizeof(float)); pp1 += sizeof(float);
306+
memcpy(&d1, pp1, sizeof(float)); pp1 += sizeof(float);
307+
308+
//printf("min0/d0 = %f %f | min1/d1 = %f %f\n", min0, d0, min1, d1);
309+
310+
#if 1
311+
// >>> General case for any QB
312+
313+
s0[0] = min0;
314+
s1[0] = min1;
315+
316+
for (int b = 0; b < QB; b++) {
317+
s0[b + 1] = d0*(1 << b);
318+
s1[b + 1] = d1*(1 << b);
319+
}
175320

176321
m0[0] = -1LL;
177322
m1[0] = -1LL;
@@ -198,6 +343,8 @@ void mul_mat_vec_gq_0(
198343
}
199344

200345
int main(int argc, const char ** argv) {
346+
assert(sizeof(gq_t)*8 == gq_t_bits);
347+
201348
float * src0 = (float *)malloc(sizeof(float)*M*K);
202349
float * src1 = (float *)malloc(sizeof(float)*N*K);
203350
float * dst = (float *)malloc(sizeof(float)*M*N);
@@ -219,20 +366,27 @@ int main(int argc, const char ** argv) {
219366

220367
printf("compression: %f\n", (float)sizegq/sizef16);
221368

369+
int method = 0;
370+
if (argc > 1) {
371+
method = atoi(argv[1]);
372+
}
373+
222374
// convert fp32 -> gq
223375
{
224376
const uint64_t t_start = get_time_us();
225377

226-
quantize(src0, src0_gq, M, K);
227-
quantize(src1, src1_gq, N, K);
378+
if (method == 1) {
379+
quantize_1(src0, src0_gq, M, K);
380+
quantize_1(src1, src1_gq, N, K);
381+
}
228382

229-
const uint64_t t_end = get_time_us();
230-
printf("convert time: %f ms\n", (t_end - t_start) / 1000.0);
231-
}
383+
if (method == 2) {
384+
quantize_2(src0, src0_gq, M, K);
385+
quantize_2(src1, src1_gq, N, K);
386+
}
232387

233-
int method = 0;
234-
if (argc > 1) {
235-
method = atoi(argv[1]);
388+
const uint64_t t_end = get_time_us();
389+
printf("convert time: %f ms / method = %d\n", (t_end - t_start) / 1000.0, method);
236390
}
237391

238392
const int nIter = 1;
@@ -244,11 +398,15 @@ int main(int argc, const char ** argv) {
244398
double sum = 0.0f;
245399
for (int i = 0; i < nIter; i++) {
246400
if (method == 0) {
247-
mul_mat_vec_f32_0(src0, src1, dst, M, N, K);
401+
mul_mat_vec_f32_naive(src0, src1, dst, M, N, K);
248402
}
249403

250404
if (method == 1) {
251-
mul_mat_vec_gq_0(src0_gq, src1_gq, dst, M, N, K);
405+
mul_mat_vec_gq_1(src0_gq, src1_gq, dst, M, N, K);
406+
}
407+
408+
if (method == 2) {
409+
mul_mat_vec_gq_1(src0_gq, src1_gq, dst, M, N, K);
252410
}
253411
}
254412

0 commit comments

Comments
 (0)