Skip to content

Commit 96b3d41

Browse files
committed
ggml-quants : allow using vdotq_s32 in TQ2_0 vec_dot
Not yet tested on hardware which supports it; it might not work or might not even compile — but it also might. It should improve performance on recent ARM CPUs.

* ggml-quants : remove comment about possible format change of TQ2_0

Making it slightly more convenient for AVX512 but less convenient for everything else is not worth the trouble.
1 parent f034aa1 commit 96b3d41

File tree

1 file changed

+61
-2
lines changed

1 file changed

+61
-2
lines changed

ggml/src/ggml-quants.c

Lines changed: 61 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -3391,7 +3391,6 @@ void quantize_row_tq2_0_ref(const float * restrict x, block_tq2_0 * restrict y,
33913391

33923392
y[i].d = GGML_FP32_TO_FP16(d);
33933393

3394-
// TODO: should it be along 64 bytes instead for AVX512?
33953394
for (size_t j = 0; j < sizeof(y->qs); j += 32) {
33963395
for (size_t m = 0; m < 32; ++m) {
33973396
uint8_t q = 0;
@@ -5957,7 +5956,67 @@ void ggml_vec_dot_tq2_0_q8_K(int n, float * restrict s, size_t bs, const void *
59575956

59585957
const int nb = n / QK_K;
59595958

5960-
#if defined __ARM_NEON
5959+
#if defined __ARM_NEON && defined __ARM_FEATURE_DOTPROD
5960+
float sumf = 0.0f;
5961+
5962+
const uint8x16_t m3 = vdupq_n_u8(3);
5963+
5964+
for (int i = 0; i < nb; ++i) {
5965+
int32x4_t sumi0 = vdupq_n_s32(0);
5966+
int32x4_t sumi1 = vdupq_n_s32(0);
5967+
5968+
for (size_t j = 0; j < sizeof(x->qs); j += 32) {
5969+
uint8x16_t qx0 = vld1q_u8(x[i].qs + j);
5970+
uint8x16_t qx1 = vld1q_u8(x[i].qs + j + 16);
5971+
uint8x16_t qx2 = vshrq_n_u8(qx0, 2);
5972+
uint8x16_t qx3 = vshrq_n_u8(qx1, 2);
5973+
uint8x16_t qx4 = vshrq_n_u8(qx0, 4);
5974+
uint8x16_t qx5 = vshrq_n_u8(qx1, 4);
5975+
uint8x16_t qx6 = vshrq_n_u8(qx0, 6);
5976+
uint8x16_t qx7 = vshrq_n_u8(qx1, 6);
5977+
5978+
int8x16_t sqx0 = vreinterpretq_s8_u8(vandq_u8(qx0, m3));
5979+
int8x16_t sqx1 = vreinterpretq_s8_u8(vandq_u8(qx1, m3));
5980+
int8x16_t sqx2 = vreinterpretq_s8_u8(vandq_u8(qx2, m3));
5981+
int8x16_t sqx3 = vreinterpretq_s8_u8(vandq_u8(qx3, m3));
5982+
int8x16_t sqx4 = vreinterpretq_s8_u8(vandq_u8(qx4, m3));
5983+
int8x16_t sqx5 = vreinterpretq_s8_u8(vandq_u8(qx5, m3));
5984+
int8x16_t sqx6 = vreinterpretq_s8_u8(vandq_u8(qx6, m3));
5985+
int8x16_t sqx7 = vreinterpretq_s8_u8(vandq_u8(qx7, m3));
5986+
5987+
const int8x16_t qy0 = vld1q_s8(y[i].qs + j*4 + 0);
5988+
const int8x16_t qy1 = vld1q_s8(y[i].qs + j*4 + 16);
5989+
const int8x16_t qy2 = vld1q_s8(y[i].qs + j*4 + 32);
5990+
const int8x16_t qy3 = vld1q_s8(y[i].qs + j*4 + 48);
5991+
const int8x16_t qy4 = vld1q_s8(y[i].qs + j*4 + 64);
5992+
const int8x16_t qy5 = vld1q_s8(y[i].qs + j*4 + 80);
5993+
const int8x16_t qy6 = vld1q_s8(y[i].qs + j*4 + 96);
5994+
const int8x16_t qy7 = vld1q_s8(y[i].qs + j*4 + 112);
5995+
5996+
sumi0 = vdotq_s32(sumi0, sqx0, qy0);
5997+
sumi1 = vdotq_s32(sumi1, sqx1, qy1);
5998+
sumi0 = vdotq_s32(sumi0, sqx2, qy2);
5999+
sumi1 = vdotq_s32(sumi1, sqx3, qy3);
6000+
sumi0 = vdotq_s32(sumi0, sqx4, qy4);
6001+
sumi1 = vdotq_s32(sumi1, sqx5, qy5);
6002+
sumi0 = vdotq_s32(sumi0, sqx6, qy6);
6003+
sumi1 = vdotq_s32(sumi1, sqx7, qy7);
6004+
}
6005+
6006+
const int16x8_t ysum0 = vld1q_s16(y[i].bsums);
6007+
const int16x8_t ysum1 = vld1q_s16(y[i].bsums + 8);
6008+
6009+
sumi0 = vaddq_s32(sumi0, sumi1);
6010+
sumi0 = vsubq_s32(sumi0, vpaddlq_s16(vaddq_s16(ysum0, ysum1)));
6011+
6012+
const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
6013+
6014+
sumf += d * (float) vaddvq_s32(sumi0);
6015+
}
6016+
6017+
*s = sumf;
6018+
6019+
#elif defined __ARM_NEON
59616020
float sumf = 0.0f;
59626021

59636022
const uint8x16_t m3 = vdupq_n_u8(3);

0 commit comments

Comments
 (0)