|
| 1 | +// quantized matrix multiplication |
| 2 | + |
| 3 | +#include <float.h> |
| 4 | +#include <stdint.h> |
| 5 | +#include <stdio.h> |
| 6 | +#include <assert.h> |
| 7 | +#include <stdlib.h> |
| 8 | +#include <string.h> |
| 9 | +#include <time.h> |
| 10 | +#include <math.h> |
| 11 | + |
| 12 | +#include <sys/time.h> |
| 13 | + |
| 14 | +#ifdef __ARM_NEON |
| 15 | +#include "arm_neon.h" |
| 16 | +#endif |
| 17 | + |
| 18 | +#ifndef MIN |
| 19 | +#define MAX(a, b) ((a) > (b) ? (a) : (b)) |
| 20 | +#define MIN(a, b) ((a) < (b) ? (a) : (b)) |
| 21 | +#endif |
| 22 | + |
| 23 | +const int M = 1280; |
| 24 | +const int N = 1536; |
| 25 | +const int K = 1280; |
| 26 | + |
| 27 | +const int QK = 64; |
| 28 | +const int QB = 7; |
| 29 | + |
| 30 | +#define gq_t uint64_t |
| 31 | +#define gq_t_bits 64 |
| 32 | + |
| 33 | +uint64_t get_time_us() { |
| 34 | + struct timeval tv; |
| 35 | + gettimeofday(&tv, NULL); |
| 36 | + return tv.tv_sec * 1000000 + tv.tv_usec; |
| 37 | +} |
| 38 | + |
| 39 | +// |
| 40 | +// naive implementation |
| 41 | +// |
| 42 | + |
| 43 | +void mul_mat_vec_f32_0( |
| 44 | + const float * restrict src0, // M x K |
| 45 | + const float * restrict src1, // N x K (transposed) |
| 46 | + float * dst, |
| 47 | + int m, int n, int k) { |
| 48 | + for (int i = 0; i < m; i++) { |
| 49 | + for (int j = 0; j < n; j++) { |
| 50 | + float sum = 0; |
| 51 | + for (int l = 0; l < k; l++) { |
| 52 | + sum += src0[i*k + l] * src1[j*k + l]; |
| 53 | + } |
| 54 | + dst[i*n + j] = sum; |
| 55 | + } |
| 56 | + } |
| 57 | +} |
| 58 | + |
| 59 | +void quantize(const float * src, void * dst, int n, int k) { |
| 60 | + char * p0 = dst; |
| 61 | + |
| 62 | + for (int j = 0; j < n; j++) { |
| 63 | + for (int i = 0; i < k/QK; i++) { |
| 64 | + float min = FLT_MAX; |
| 65 | + float max = -FLT_MAX; |
| 66 | + |
| 67 | + // find min/max |
| 68 | +#ifdef __ARM_NEON |
| 69 | + { |
| 70 | + float32x4_t minv = vdupq_n_f32(FLT_MAX); |
| 71 | + float32x4_t maxv = vdupq_n_f32(-FLT_MAX); |
| 72 | + |
| 73 | + for (int l = 0; l < QK; l += 4) { |
| 74 | + float32x4_t v = vld1q_f32(src + j*k + i*QK + l); |
| 75 | + minv = vminq_f32(minv, v); |
| 76 | + maxv = vmaxq_f32(maxv, v); |
| 77 | + } |
| 78 | + |
| 79 | + float32x2_t minv32 = vpmin_f32(vget_low_f32(minv), vget_high_f32(minv)); |
| 80 | + float32x2_t maxv32 = vpmax_f32(vget_low_f32(maxv), vget_high_f32(maxv)); |
| 81 | + |
| 82 | + min = MIN(vget_lane_f32(minv32, 0), vget_lane_f32(minv32, 1)); |
| 83 | + max = MAX(vget_lane_f32(maxv32, 0), vget_lane_f32(maxv32, 1)); |
| 84 | + |
| 85 | + //printf("SIMD min/max: %f %f\n", min, max); |
| 86 | + } |
| 87 | +#else |
| 88 | + { |
| 89 | + for (int l = 0; l < QK; l++) { |
| 90 | + const float v = src[j*k + i*QK + l]; |
| 91 | + if (v < min) min = v; |
| 92 | + if (v > max) max = v; |
| 93 | + } |
| 94 | + |
| 95 | + //printf("NORM min/max: %f %f\n", min, max); |
| 96 | + } |
| 97 | +#endif |
| 98 | + |
| 99 | + const float d = (max - min) / ((1 << QB) - 1); |
| 100 | + const float id = d ? 1.0/d : 0.0; |
| 101 | + |
| 102 | + memcpy(p0, &min, sizeof(float)); p0 += sizeof(float); |
| 103 | + memcpy(p0, &d, sizeof(float)); p0 += sizeof(float); |
| 104 | + |
| 105 | + //printf("min/max/d/id: %f %f %f %f\n", min, max, d, id); |
| 106 | + |
| 107 | + for (int s = 0; s < QK/gq_t_bits; ++s) { |
| 108 | + gq_t pp[QB] = {0}; |
| 109 | + |
| 110 | + for (int l = 0; l < gq_t_bits; l++) { |
| 111 | + const float v = src[j*k + i*QK + s*gq_t_bits + l]; |
| 112 | + const uint8_t q = (v - min)*id; |
| 113 | + |
| 114 | + for (int b = 0; b < QB; b++) { |
| 115 | + pp[b] |= q & (1 << b) ? (1LL << l) : 0; |
| 116 | + } |
| 117 | + } |
| 118 | + |
| 119 | + for (int b = 0; b < QB; b++) { |
| 120 | + memcpy(p0, &pp[b], sizeof(gq_t)); p0 += sizeof(gq_t); |
| 121 | + } |
| 122 | + } |
| 123 | + } |
| 124 | + } |
| 125 | +} |
| 126 | + |
| 127 | +void mul_mat_vec_gq_0( |
| 128 | + const void * src0, |
| 129 | + const void * src1, |
| 130 | + float * dst, |
| 131 | + int m, int n, int k) { |
| 132 | + const int kp = k & ~(gq_t_bits - 1); |
| 133 | + |
| 134 | + const char * restrict p0 = src0; |
| 135 | + const char * restrict p1 = src1; |
| 136 | + |
| 137 | + for (int ir0 = 0; ir0 < m; ir0++) { |
| 138 | + for (int ir1 = 0; ir1 < n; ir1++) { |
| 139 | + float sumf = 0.0; |
| 140 | + |
| 141 | + const char * restrict pp0 = p0 + ir0*((2*sizeof(float) + (QK/gq_t_bits)*QB*sizeof(gq_t))*(k/QK)); |
| 142 | + const char * restrict pp1 = p1 + ir1*((2*sizeof(float) + (QK/gq_t_bits)*QB*sizeof(gq_t))*(k/QK)); |
| 143 | + |
| 144 | + for (int i = 0; i < kp/QK; i++) { |
| 145 | + float min0, d0; |
| 146 | + memcpy(&min0, pp0, sizeof(float)); pp0 += sizeof(float); |
| 147 | + memcpy(&d0, pp0, sizeof(float)); pp0 += sizeof(float); |
| 148 | + |
| 149 | + float min1, d1; |
| 150 | + memcpy(&min1, pp1, sizeof(float)); pp1 += sizeof(float); |
| 151 | + memcpy(&d1, pp1, sizeof(float)); pp1 += sizeof(float); |
| 152 | + |
| 153 | + //printf("min0/d0 = %f %f | min1/d1 = %f %f\n", min0, d0, min1, d1); |
| 154 | + |
| 155 | +#if 1 |
| 156 | + // >>> General case for any QB |
| 157 | + |
| 158 | + float s0[QB + 1]; |
| 159 | + float s1[QB + 1]; |
| 160 | + |
| 161 | + s0[0] = min0; |
| 162 | + s1[0] = min1; |
| 163 | + |
| 164 | + for (int b = 0; b < QB; b++) { |
| 165 | + s0[b + 1] = d0*(1 << b); |
| 166 | + s1[b + 1] = d1*(1 << b); |
| 167 | + } |
| 168 | + |
| 169 | + gq_t m0[QB + 1]; |
| 170 | + gq_t m1[QB + 1]; |
| 171 | + |
| 172 | + m0[0] = -1LL; |
| 173 | + m1[0] = -1LL; |
| 174 | + |
| 175 | + for (int s = 0; s < QK/gq_t_bits; ++s) { |
| 176 | + for (int b = 0; b < QB; b++) { |
| 177 | + memcpy(&m0[b + 1], pp0, sizeof(gq_t)); pp0 += sizeof(gq_t); |
| 178 | + memcpy(&m1[b + 1], pp1, sizeof(gq_t)); pp1 += sizeof(gq_t); |
| 179 | + } |
| 180 | + |
| 181 | + for (int q0 = 0; q0 < QB + 1; q0++) { |
| 182 | + for (int q1 = 0; q1 < QB + 1; q1++) { |
| 183 | + sumf += s0[q0]*s1[q1]*__builtin_popcountll(m0[q0] & m1[q1]); |
| 184 | + } |
| 185 | + } |
| 186 | + } |
| 187 | +#else |
| 188 | +#endif |
| 189 | + } |
| 190 | + |
| 191 | + dst[ir0*n + ir1] = sumf; |
| 192 | + } |
| 193 | + } |
| 194 | +} |
| 195 | + |
| 196 | +int main(int argc, const char ** argv) { |
| 197 | + float * src0 = (float *)malloc(sizeof(float)*M*K); |
| 198 | + float * src1 = (float *)malloc(sizeof(float)*N*K); |
| 199 | + float * dst = (float *)malloc(sizeof(float)*M*N); |
| 200 | + |
| 201 | + for (int i = 0; i < M*K; i++) { |
| 202 | + src0[i] = rand() / (float)RAND_MAX; |
| 203 | + } |
| 204 | + |
| 205 | + for (int i = 0; i < N*K; i++) { |
| 206 | + src1[i] = rand() / (float)RAND_MAX; |
| 207 | + } |
| 208 | + |
| 209 | + void * src0_gq = calloc(1, (2*sizeof(float) + (QK/gq_t_bits)*QB*sizeof(gq_t))*(K/QK)*M); |
| 210 | + void * src1_gq = calloc(1, (2*sizeof(float) + (QK/gq_t_bits)*QB*sizeof(gq_t))*(K/QK)*N); |
| 211 | + |
| 212 | + const size_t sizef16 = sizeof(__fp16)*M*K + sizeof(__fp16)*N*K; |
| 213 | + const size_t sizegq = (2*sizeof(float) + (QK/gq_t_bits)*QB*sizeof(gq_t))*(K/QK)*M + |
| 214 | + (2*sizeof(float) + (QK/gq_t_bits)*QB*sizeof(gq_t))*(K/QK)*N; |
| 215 | + |
| 216 | + printf("compression: %f\n", (float)sizegq/sizef16); |
| 217 | + |
| 218 | + // convert fp32 -> gq |
| 219 | + { |
| 220 | + const uint64_t t_start = get_time_us(); |
| 221 | + |
| 222 | + quantize(src0, src0_gq, M, K); |
| 223 | + quantize(src1, src1_gq, N, K); |
| 224 | + |
| 225 | + const uint64_t t_end = get_time_us(); |
| 226 | + printf("convert time: %f ms\n", (t_end - t_start) / 1000.0); |
| 227 | + } |
| 228 | + |
| 229 | + int method = 0; |
| 230 | + if (argc > 1) { |
| 231 | + method = atoi(argv[1]); |
| 232 | + } |
| 233 | + |
| 234 | + const int nIter = 1; |
| 235 | + |
| 236 | + const clock_t start = clock(); |
| 237 | + const uint64_t start_us = get_time_us(); |
| 238 | + |
| 239 | + double iM = 1.0/M; |
| 240 | + double sum = 0.0f; |
| 241 | + for (int i = 0; i < nIter; i++) { |
| 242 | + if (method == 0) { |
| 243 | + mul_mat_vec_f32_0(src0, src1, dst, M, N, K); |
| 244 | + } |
| 245 | + |
| 246 | + if (method == 1) { |
| 247 | + mul_mat_vec_gq_0(src0_gq, src1_gq, dst, M, N, K); |
| 248 | + } |
| 249 | + } |
| 250 | + |
| 251 | + for (int i = 0; i < N; i++) { |
| 252 | + sum += dst[i]*iM; |
| 253 | + } |
| 254 | + |
| 255 | + { |
| 256 | + const clock_t end = clock(); |
| 257 | + const uint64_t end_us = get_time_us(); |
| 258 | + printf("%s: elapsed ticks: %ld\n", __func__, end - start); |
| 259 | + printf("%s: elapsed us: %llu / %f ms\n", __func__, end_us - start_us, (end_us - start_us) / 1000.0 / nIter); |
| 260 | + } |
| 261 | + |
| 262 | + printf("%f\n", sum); |
| 263 | + |
| 264 | + free(src0); |
| 265 | + free(src1); |
| 266 | + free(dst); |
| 267 | + |
| 268 | + free(src0_gq); |
| 269 | + free(src1_gq); |
| 270 | + |
| 271 | + return 0; |
| 272 | +} |
0 commit comments