@@ -42,7 +42,7 @@ uint64_t get_time_us() {
42
42
// naive implementation
43
43
//
44
44
45
- void mul_mat_vec_f32_0 (
45
+ void mul_mat_vec_f32_naive (
46
46
const float * restrict src0 , // M x K
47
47
const float * restrict src1 , // N x K (transposed)
48
48
float * dst ,
@@ -58,7 +58,11 @@ void mul_mat_vec_f32_0(
58
58
}
59
59
}
60
60
61
- void quantize (const float * src , void * dst , int n , int k ) {
61
+ //
62
+ // method 1
63
+ //
64
+
65
+ void quantize_1 (const float * src , void * dst , int n , int k ) {
62
66
char * p0 = dst ;
63
67
64
68
gq_t pp [QB ];
@@ -128,7 +132,7 @@ void quantize(const float * src, void * dst, int n, int k) {
128
132
}
129
133
}
130
134
131
- void mul_mat_vec_gq_0 (
135
+ void mul_mat_vec_gq_1 (
132
136
const void * src0 ,
133
137
const void * src1 ,
134
138
float * dst ,
@@ -138,6 +142,12 @@ void mul_mat_vec_gq_0(
138
142
const char * restrict p0 = src0 ;
139
143
const char * restrict p1 = src1 ;
140
144
145
+ float s0 [QB + 1 ];
146
+ float s1 [QB + 1 ];
147
+
148
+ gq_t m0 [QB + 1 ];
149
+ gq_t m1 [QB + 1 ];
150
+
141
151
for (int ir0 = 0 ; ir0 < m ; ir0 ++ ) {
142
152
for (int ir1 = 0 ; ir1 < n ; ir1 ++ ) {
143
153
float sumf = 0.0 ;
@@ -159,9 +169,6 @@ void mul_mat_vec_gq_0(
159
169
#if 1
160
170
// >>> General case for any QB
161
171
162
- float s0 [QB + 1 ];
163
- float s1 [QB + 1 ];
164
-
165
172
s0 [0 ] = min0 ;
166
173
s1 [0 ] = min1 ;
167
174
@@ -170,8 +177,146 @@ void mul_mat_vec_gq_0(
170
177
s1 [b + 1 ] = d1 * (1 << b );
171
178
}
172
179
173
- gq_t m0 [QB + 1 ];
174
- gq_t m1 [QB + 1 ];
180
+ m0 [0 ] = -1LL ;
181
+ m1 [0 ] = -1LL ;
182
+
183
+ for (int s = 0 ; s < QK /gq_t_bits ; ++ s ) {
184
+ for (int b = 0 ; b < QB ; b ++ ) {
185
+ memcpy (& m0 [b + 1 ], pp0 , sizeof (gq_t )); pp0 += sizeof (gq_t );
186
+ memcpy (& m1 [b + 1 ], pp1 , sizeof (gq_t )); pp1 += sizeof (gq_t );
187
+ }
188
+
189
+ for (int q0 = 0 ; q0 < QB + 1 ; q0 ++ ) {
190
+ for (int q1 = 0 ; q1 < QB + 1 ; q1 ++ ) {
191
+ sumf += s0 [q0 ]* s1 [q1 ]* __builtin_popcountll (m0 [q0 ] & m1 [q1 ]);
192
+ }
193
+ }
194
+ }
195
+ #else
196
+ #endif
197
+ }
198
+
199
+ dst [ir0 * n + ir1 ] = sumf ;
200
+ }
201
+ }
202
+ }
203
+
204
+ //
205
+ // method 2
206
+ //
207
+
208
+ void quantize_2 (const float * src , void * dst , int n , int k ) {
209
+ char * p0 = dst ;
210
+
211
+ for (int j = 0 ; j < n ; j ++ ) {
212
+ for (int i = 0 ; i < k /QK ; i ++ ) {
213
+ float min = FLT_MAX ;
214
+ float max = - FLT_MAX ;
215
+
216
+ // find min/max
217
+ #ifdef __ARM_NEON
218
+ {
219
+ float32x4_t minv = vdupq_n_f32 (FLT_MAX );
220
+ float32x4_t maxv = vdupq_n_f32 (- FLT_MAX );
221
+
222
+ for (int l = 0 ; l < QK ; l += 4 ) {
223
+ float32x4_t v = vld1q_f32 (src + j * k + i * QK + l );
224
+ minv = vminq_f32 (minv , v );
225
+ maxv = vmaxq_f32 (maxv , v );
226
+ }
227
+
228
+ float32x2_t minv32 = vpmin_f32 (vget_low_f32 (minv ), vget_high_f32 (minv ));
229
+ float32x2_t maxv32 = vpmax_f32 (vget_low_f32 (maxv ), vget_high_f32 (maxv ));
230
+
231
+ min = MIN (vget_lane_f32 (minv32 , 0 ), vget_lane_f32 (minv32 , 1 ));
232
+ max = MAX (vget_lane_f32 (maxv32 , 0 ), vget_lane_f32 (maxv32 , 1 ));
233
+
234
+ //printf("SIMD min/max: %f %f\n", min, max);
235
+ }
236
+ #else
237
+ {
238
+ for (int l = 0 ; l < QK ; l ++ ) {
239
+ const float v = src [j * k + i * QK + l ];
240
+ if (v < min ) min = v ;
241
+ if (v > max ) max = v ;
242
+ }
243
+
244
+ //printf("NORM min/max: %f %f\n", min, max);
245
+ }
246
+ #endif
247
+
248
+ const float d = (max - min ) / ((1 << QB ) - 1 );
249
+ const float id = d ? 1.0 /d : 0.0 ;
250
+
251
+ memcpy (p0 , & min , sizeof (float )); p0 += sizeof (float );
252
+ memcpy (p0 , & d , sizeof (float )); p0 += sizeof (float );
253
+
254
+ //printf("min/max/d/id: %f %f %f %f\n", min, max, d, id);
255
+
256
+ for (int s = 0 ; s < QK /gq_t_bits ; ++ s ) {
257
+ gq_t pp [QB ] = {0 };
258
+
259
+ for (int l = 0 ; l < gq_t_bits ; l ++ ) {
260
+ const float v = src [j * k + i * QK + s * gq_t_bits + l ];
261
+ const uint8_t q = (v - min )* id ;
262
+
263
+ for (int b = 0 ; b < QB ; b ++ ) {
264
+ pp [b ] |= q & (1 << b ) ? (1LL << l ) : 0 ;
265
+ }
266
+ }
267
+
268
+ for (int b = 0 ; b < QB ; b ++ ) {
269
+ memcpy (p0 , & pp [b ], sizeof (gq_t )); p0 += sizeof (gq_t );
270
+ }
271
+ }
272
+ }
273
+ }
274
+ }
275
+
276
+ void mul_mat_vec_gq_2 (
277
+ const void * src0 ,
278
+ const void * src1 ,
279
+ float * dst ,
280
+ int m , int n , int k ) {
281
+ const int kp = k & ~(gq_t_bits - 1 );
282
+
283
+ const char * restrict p0 = src0 ;
284
+ const char * restrict p1 = src1 ;
285
+
286
+ float s0 [QB + 1 ];
287
+ float s1 [QB + 1 ];
288
+
289
+ gq_t m0 [QB + 1 ];
290
+ gq_t m1 [QB + 1 ];
291
+
292
+ for (int ir0 = 0 ; ir0 < m ; ir0 ++ ) {
293
+ for (int ir1 = 0 ; ir1 < n ; ir1 ++ ) {
294
+ float sumf = 0.0 ;
295
+
296
+ const char * restrict pp0 = p0 + ir0 * ((2 * sizeof (float ) + (QK /gq_t_bits )* QB * sizeof (gq_t ))* (k /QK ));
297
+ const char * restrict pp1 = p1 + ir1 * ((2 * sizeof (float ) + (QK /gq_t_bits )* QB * sizeof (gq_t ))* (k /QK ));
298
+
299
+ for (int i = 0 ; i < kp /QK ; i ++ ) {
300
+ float min0 , d0 ;
301
+ memcpy (& min0 , pp0 , sizeof (float )); pp0 += sizeof (float );
302
+ memcpy (& d0 , pp0 , sizeof (float )); pp0 += sizeof (float );
303
+
304
+ float min1 , d1 ;
305
+ memcpy (& min1 , pp1 , sizeof (float )); pp1 += sizeof (float );
306
+ memcpy (& d1 , pp1 , sizeof (float )); pp1 += sizeof (float );
307
+
308
+ //printf("min0/d0 = %f %f | min1/d1 = %f %f\n", min0, d0, min1, d1);
309
+
310
+ #if 1
311
+ // >>> General case for any QB
312
+
313
+ s0 [0 ] = min0 ;
314
+ s1 [0 ] = min1 ;
315
+
316
+ for (int b = 0 ; b < QB ; b ++ ) {
317
+ s0 [b + 1 ] = d0 * (1 << b );
318
+ s1 [b + 1 ] = d1 * (1 << b );
319
+ }
175
320
176
321
m0 [0 ] = -1LL ;
177
322
m1 [0 ] = -1LL ;
@@ -198,6 +343,8 @@ void mul_mat_vec_gq_0(
198
343
}
199
344
200
345
int main (int argc , const char * * argv ) {
346
+ assert (sizeof (gq_t )* 8 == gq_t_bits );
347
+
201
348
float * src0 = (float * )malloc (sizeof (float )* M * K );
202
349
float * src1 = (float * )malloc (sizeof (float )* N * K );
203
350
float * dst = (float * )malloc (sizeof (float )* M * N );
@@ -219,20 +366,27 @@ int main(int argc, const char ** argv) {
219
366
220
367
printf ("compression: %f\n" , (float )sizegq /sizef16 );
221
368
369
+ int method = 0 ;
370
+ if (argc > 1 ) {
371
+ method = atoi (argv [1 ]);
372
+ }
373
+
222
374
// convert fp32 -> gq
223
375
{
224
376
const uint64_t t_start = get_time_us ();
225
377
226
- quantize (src0 , src0_gq , M , K );
227
- quantize (src1 , src1_gq , N , K );
378
+ if (method == 1 ) {
379
+ quantize_1 (src0 , src0_gq , M , K );
380
+ quantize_1 (src1 , src1_gq , N , K );
381
+ }
228
382
229
- const uint64_t t_end = get_time_us ();
230
- printf ("convert time: %f ms\n" , (t_end - t_start ) / 1000.0 );
231
- }
383
+ if (method == 2 ) {
384
+ quantize_2 (src0 , src0_gq , M , K );
385
+ quantize_2 (src1 , src1_gq , N , K );
386
+ }
232
387
233
- int method = 0 ;
234
- if (argc > 1 ) {
235
- method = atoi (argv [1 ]);
388
+ const uint64_t t_end = get_time_us ();
389
+ printf ("convert time: %f ms / method = %d\n" , (t_end - t_start ) / 1000.0 , method );
236
390
}
237
391
238
392
const int nIter = 1 ;
@@ -244,11 +398,15 @@ int main(int argc, const char ** argv) {
244
398
double sum = 0.0f ;
245
399
for (int i = 0 ; i < nIter ; i ++ ) {
246
400
if (method == 0 ) {
247
- mul_mat_vec_f32_0 (src0 , src1 , dst , M , N , K );
401
+ mul_mat_vec_f32_naive (src0 , src1 , dst , M , N , K );
248
402
}
249
403
250
404
if (method == 1 ) {
251
- mul_mat_vec_gq_0 (src0_gq , src1_gq , dst , M , N , K );
405
+ mul_mat_vec_gq_1 (src0_gq , src1_gq , dst , M , N , K );
406
+ }
407
+
408
+ if (method == 2 ) {
409
+ mul_mat_vec_gq_1 (src0_gq , src1_gq , dst , M , N , K );
252
410
}
253
411
}
254
412
0 commit comments