Skip to content

Commit 446ccf3

Browse files
committed
tests : experiments with n-bit quantized matrix multiplication
1 parent bd9f710 commit 446ccf3

File tree

4 files changed

+287
-6
lines changed

4 files changed

+287
-6
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@ compile_commands.json
99
.DS_Store
1010

1111
src/arm_neon.h
12+
tests/arm_neon.h

tests/CMakeLists.txt

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,14 @@ if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm" AND NOT GGML_NO_ACCELERATE)
6565
add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)
6666
endif()
6767

68+
#
69+
# test-mul-mat2
70+
71+
set(TEST_TARGET test-mul-mat2)
72+
add_executable(${TEST_TARGET} ${TEST_TARGET}.c)
73+
target_link_libraries(${TEST_TARGET} PRIVATE ggml)
74+
add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)
75+
6876
#
6977
# test0
7078

tests/test-mul-mat1.c

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,12 @@ const int M = 1280;
1616
const int N = 1500;
1717
const int K = 1280;
1818

19+
// Returns the current wall-clock time, as reported by gettimeofday(), in
// microseconds.
uint64_t get_time_us(void) {
    struct timeval tv;
    gettimeofday(&tv, NULL);

    // Widen before multiplying: the product is otherwise computed in the type
    // of tv_sec, which overflows on platforms where that type is 32 bits wide.
    return (uint64_t) tv.tv_sec * 1000000 + (uint64_t) tv.tv_usec;
}
24+
1925
//
2026
// naive implementation
2127
//
@@ -206,12 +212,6 @@ void mul_mat_vec_f8_0(
206212
}
207213
}
208214

209-
// Returns the current wall-clock time, as reported by gettimeofday(), in
// microseconds.
uint64_t get_time_us(void) {
    struct timeval tv;
    gettimeofday(&tv, NULL);

    // Widen before multiplying: the product is otherwise computed in the type
    // of tv_sec, which overflows on platforms where that type is 32 bits wide.
    return (uint64_t) tv.tv_sec * 1000000 + (uint64_t) tv.tv_usec;
}
214-
215215
int main(int argc, const char ** argv) {
216216
float * src0 = (float *)malloc(sizeof(float)*M*K);
217217
float * src1 = (float *)malloc(sizeof(float)*N*K);

tests/test-mul-mat2.c

Lines changed: 272 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,272 @@
1+
// quantized matrix multiplication
2+
3+
#include <float.h>
4+
#include <stdint.h>
5+
#include <stdio.h>
6+
#include <assert.h>
7+
#include <stdlib.h>
8+
#include <string.h>
9+
#include <time.h>
10+
#include <math.h>
11+
12+
#include <sys/time.h>
13+
14+
#ifdef __ARM_NEON
15+
#include "arm_neon.h"
16+
#endif
17+
18+
// Fallback MIN/MAX for platforms whose headers do not provide them.
// Each macro has its own guard so that a header defining only one of the two
// neither leaves the other missing nor causes a redefinition.
// Note: both macros evaluate their arguments more than once.
#ifndef MAX
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#endif

#ifndef MIN
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#endif
22+
23+
// problem size: src0 is M x K, src1 is N x K (transposed), dst is M x N
const int M = 1280;
const int N = 1536;
const int K = 1280;

// Quantization parameters. These are enum constants rather than `const int`
// objects so that they are integer constant expressions: arrays sized with
// them are ordinary arrays (not VLAs) and may carry an initializer.
enum {
    QK = 64, // number of values per quantization chunk
    QB = 7,  // number of bits per quantized value
};

// word type that stores one bit-plane, and its width in bits
#define gq_t uint64_t
#define gq_t_bits 64
32+
33+
// Returns the current wall-clock time, as reported by gettimeofday(), in
// microseconds.
uint64_t get_time_us(void) {
    struct timeval tv;
    gettimeofday(&tv, NULL);

    // Widen before multiplying: the product is otherwise computed in the type
    // of tv_sec, which overflows on platforms where that type is 32 bits wide.
    return (uint64_t) tv.tv_sec * 1000000 + (uint64_t) tv.tv_usec;
}
38+
39+
//
40+
// naive implementation
41+
//
42+
43+
// Reference fp32 matrix product, written as plain nested loops.
//
// src0 holds m rows of k floats and src1 holds n rows of k floats (the
// right-hand matrix is stored transposed). dst receives the m x n result in
// row-major order: dst[i*n + j] is the dot product of row i of src0 with
// row j of src1.
void mul_mat_vec_f32_0(
    const float * restrict src0, // M x K
    const float * restrict src1, // N x K (transposed)
    float * dst,
    int m, int n, int k) {
    for (int row = 0; row < m; ++row) {
        const float * a   = src0 + row*k;
        float       * out = dst  + row*n;

        for (int col = 0; col < n; ++col) {
            const float * b = src1 + col*k;

            float acc = 0;
            for (int t = 0; t < k; ++t) {
                acc += a[t] * b[t];
            }

            out[col] = acc;
        }
    }
}
58+
59+
void quantize(const float * src, void * dst, int n, int k) {
60+
char * p0 = dst;
61+
62+
for (int j = 0; j < n; j++) {
63+
for (int i = 0; i < k/QK; i++) {
64+
float min = FLT_MAX;
65+
float max = -FLT_MAX;
66+
67+
// find min/max
68+
#ifdef __ARM_NEON
69+
{
70+
float32x4_t minv = vdupq_n_f32(FLT_MAX);
71+
float32x4_t maxv = vdupq_n_f32(-FLT_MAX);
72+
73+
for (int l = 0; l < QK; l += 4) {
74+
float32x4_t v = vld1q_f32(src + j*k + i*QK + l);
75+
minv = vminq_f32(minv, v);
76+
maxv = vmaxq_f32(maxv, v);
77+
}
78+
79+
float32x2_t minv32 = vpmin_f32(vget_low_f32(minv), vget_high_f32(minv));
80+
float32x2_t maxv32 = vpmax_f32(vget_low_f32(maxv), vget_high_f32(maxv));
81+
82+
min = MIN(vget_lane_f32(minv32, 0), vget_lane_f32(minv32, 1));
83+
max = MAX(vget_lane_f32(maxv32, 0), vget_lane_f32(maxv32, 1));
84+
85+
//printf("SIMD min/max: %f %f\n", min, max);
86+
}
87+
#else
88+
{
89+
for (int l = 0; l < QK; l++) {
90+
const float v = src[j*k + i*QK + l];
91+
if (v < min) min = v;
92+
if (v > max) max = v;
93+
}
94+
95+
//printf("NORM min/max: %f %f\n", min, max);
96+
}
97+
#endif
98+
99+
const float d = (max - min) / ((1 << QB) - 1);
100+
const float id = d ? 1.0/d : 0.0;
101+
102+
memcpy(p0, &min, sizeof(float)); p0 += sizeof(float);
103+
memcpy(p0, &d, sizeof(float)); p0 += sizeof(float);
104+
105+
//printf("min/max/d/id: %f %f %f %f\n", min, max, d, id);
106+
107+
for (int s = 0; s < QK/gq_t_bits; ++s) {
108+
gq_t pp[QB] = {0};
109+
110+
for (int l = 0; l < gq_t_bits; l++) {
111+
const float v = src[j*k + i*QK + s*gq_t_bits + l];
112+
const uint8_t q = (v - min)*id;
113+
114+
for (int b = 0; b < QB; b++) {
115+
pp[b] |= q & (1 << b) ? (1LL << l) : 0;
116+
}
117+
}
118+
119+
for (int b = 0; b < QB; b++) {
120+
memcpy(p0, &pp[b], sizeof(gq_t)); p0 += sizeof(gq_t);
121+
}
122+
}
123+
}
124+
}
125+
}
126+
127+
void mul_mat_vec_gq_0(
128+
const void * src0,
129+
const void * src1,
130+
float * dst,
131+
int m, int n, int k) {
132+
const int kp = k & ~(gq_t_bits - 1);
133+
134+
const char * restrict p0 = src0;
135+
const char * restrict p1 = src1;
136+
137+
for (int ir0 = 0; ir0 < m; ir0++) {
138+
for (int ir1 = 0; ir1 < n; ir1++) {
139+
float sumf = 0.0;
140+
141+
const char * restrict pp0 = p0 + ir0*((2*sizeof(float) + (QK/gq_t_bits)*QB*sizeof(gq_t))*(k/QK));
142+
const char * restrict pp1 = p1 + ir1*((2*sizeof(float) + (QK/gq_t_bits)*QB*sizeof(gq_t))*(k/QK));
143+
144+
for (int i = 0; i < kp/QK; i++) {
145+
float min0, d0;
146+
memcpy(&min0, pp0, sizeof(float)); pp0 += sizeof(float);
147+
memcpy(&d0, pp0, sizeof(float)); pp0 += sizeof(float);
148+
149+
float min1, d1;
150+
memcpy(&min1, pp1, sizeof(float)); pp1 += sizeof(float);
151+
memcpy(&d1, pp1, sizeof(float)); pp1 += sizeof(float);
152+
153+
//printf("min0/d0 = %f %f | min1/d1 = %f %f\n", min0, d0, min1, d1);
154+
155+
#if 1
156+
// >>> General case for any QB
157+
158+
float s0[QB + 1];
159+
float s1[QB + 1];
160+
161+
s0[0] = min0;
162+
s1[0] = min1;
163+
164+
for (int b = 0; b < QB; b++) {
165+
s0[b + 1] = d0*(1 << b);
166+
s1[b + 1] = d1*(1 << b);
167+
}
168+
169+
gq_t m0[QB + 1];
170+
gq_t m1[QB + 1];
171+
172+
m0[0] = -1LL;
173+
m1[0] = -1LL;
174+
175+
for (int s = 0; s < QK/gq_t_bits; ++s) {
176+
for (int b = 0; b < QB; b++) {
177+
memcpy(&m0[b + 1], pp0, sizeof(gq_t)); pp0 += sizeof(gq_t);
178+
memcpy(&m1[b + 1], pp1, sizeof(gq_t)); pp1 += sizeof(gq_t);
179+
}
180+
181+
for (int q0 = 0; q0 < QB + 1; q0++) {
182+
for (int q1 = 0; q1 < QB + 1; q1++) {
183+
sumf += s0[q0]*s1[q1]*__builtin_popcountll(m0[q0] & m1[q1]);
184+
}
185+
}
186+
}
187+
#else
188+
#endif
189+
}
190+
191+
dst[ir0*n + ir1] = sumf;
192+
}
193+
}
194+
}
195+
196+
int main(int argc, const char ** argv) {
197+
float * src0 = (float *)malloc(sizeof(float)*M*K);
198+
float * src1 = (float *)malloc(sizeof(float)*N*K);
199+
float * dst = (float *)malloc(sizeof(float)*M*N);
200+
201+
for (int i = 0; i < M*K; i++) {
202+
src0[i] = rand() / (float)RAND_MAX;
203+
}
204+
205+
for (int i = 0; i < N*K; i++) {
206+
src1[i] = rand() / (float)RAND_MAX;
207+
}
208+
209+
void * src0_gq = calloc(1, (2*sizeof(float) + (QK/gq_t_bits)*QB*sizeof(gq_t))*(K/QK)*M);
210+
void * src1_gq = calloc(1, (2*sizeof(float) + (QK/gq_t_bits)*QB*sizeof(gq_t))*(K/QK)*N);
211+
212+
const size_t sizef16 = sizeof(__fp16)*M*K + sizeof(__fp16)*N*K;
213+
const size_t sizegq = (2*sizeof(float) + (QK/gq_t_bits)*QB*sizeof(gq_t))*(K/QK)*M +
214+
(2*sizeof(float) + (QK/gq_t_bits)*QB*sizeof(gq_t))*(K/QK)*N;
215+
216+
printf("compression: %f\n", (float)sizegq/sizef16);
217+
218+
// convert fp32 -> gq
219+
{
220+
const uint64_t t_start = get_time_us();
221+
222+
quantize(src0, src0_gq, M, K);
223+
quantize(src1, src1_gq, N, K);
224+
225+
const uint64_t t_end = get_time_us();
226+
printf("convert time: %f ms\n", (t_end - t_start) / 1000.0);
227+
}
228+
229+
int method = 0;
230+
if (argc > 1) {
231+
method = atoi(argv[1]);
232+
}
233+
234+
const int nIter = 1;
235+
236+
const clock_t start = clock();
237+
const uint64_t start_us = get_time_us();
238+
239+
double iM = 1.0/M;
240+
double sum = 0.0f;
241+
for (int i = 0; i < nIter; i++) {
242+
if (method == 0) {
243+
mul_mat_vec_f32_0(src0, src1, dst, M, N, K);
244+
}
245+
246+
if (method == 1) {
247+
mul_mat_vec_gq_0(src0_gq, src1_gq, dst, M, N, K);
248+
}
249+
}
250+
251+
for (int i = 0; i < N; i++) {
252+
sum += dst[i]*iM;
253+
}
254+
255+
{
256+
const clock_t end = clock();
257+
const uint64_t end_us = get_time_us();
258+
printf("%s: elapsed ticks: %ld\n", __func__, end - start);
259+
printf("%s: elapsed us: %llu / %f ms\n", __func__, end_us - start_us, (end_us - start_us) / 1000.0 / nIter);
260+
}
261+
262+
printf("%f\n", sum);
263+
264+
free(src0);
265+
free(src1);
266+
free(dst);
267+
268+
free(src0_gq);
269+
free(src1_gq);
270+
271+
return 0;
272+
}

0 commit comments

Comments
 (0)