@@ -3391,7 +3391,6 @@ void quantize_row_tq2_0_ref(const float * restrict x, block_tq2_0 * restrict y,
 
         y[i].d = GGML_FP32_TO_FP16(d);
 
-        // TODO: should it be along 64 bytes instead for AVX512?
         for (size_t j = 0; j < sizeof(y->qs); j += 32) {
             for (size_t m = 0; m < 32; ++m) {
                 uint8_t q = 0;
@@ -5957,7 +5956,67 @@ void ggml_vec_dot_tq2_0_q8_K(int n, float * restrict s, size_t bs, const void *
 
     const int nb = n / QK_K;
 
-#if defined __ARM_NEON
+#if defined __ARM_NEON && defined __ARM_FEATURE_DOTPROD
+    float sumf = 0.0f;
+
+    const uint8x16_t m3 = vdupq_n_u8(3);
+
+    for (int i = 0; i < nb; ++i) {
+        int32x4_t sumi0 = vdupq_n_s32(0);
+        int32x4_t sumi1 = vdupq_n_s32(0);
+
+        for (size_t j = 0; j < sizeof(x->qs); j += 32) {
+            uint8x16_t qx0 = vld1q_u8(x[i].qs + j);
+            uint8x16_t qx1 = vld1q_u8(x[i].qs + j + 16);
+            uint8x16_t qx2 = vshrq_n_u8(qx0, 2);
+            uint8x16_t qx3 = vshrq_n_u8(qx1, 2);
+            uint8x16_t qx4 = vshrq_n_u8(qx0, 4);
+            uint8x16_t qx5 = vshrq_n_u8(qx1, 4);
+            uint8x16_t qx6 = vshrq_n_u8(qx0, 6);
+            uint8x16_t qx7 = vshrq_n_u8(qx1, 6);
+
+            int8x16_t sqx0 = vreinterpretq_s8_u8(vandq_u8(qx0, m3));
+            int8x16_t sqx1 = vreinterpretq_s8_u8(vandq_u8(qx1, m3));
+            int8x16_t sqx2 = vreinterpretq_s8_u8(vandq_u8(qx2, m3));
+            int8x16_t sqx3 = vreinterpretq_s8_u8(vandq_u8(qx3, m3));
+            int8x16_t sqx4 = vreinterpretq_s8_u8(vandq_u8(qx4, m3));
+            int8x16_t sqx5 = vreinterpretq_s8_u8(vandq_u8(qx5, m3));
+            int8x16_t sqx6 = vreinterpretq_s8_u8(vandq_u8(qx6, m3));
+            int8x16_t sqx7 = vreinterpretq_s8_u8(vandq_u8(qx7, m3));
+
+            const int8x16_t qy0 = vld1q_s8(y[i].qs + j*4 + 0);
+            const int8x16_t qy1 = vld1q_s8(y[i].qs + j*4 + 16);
+            const int8x16_t qy2 = vld1q_s8(y[i].qs + j*4 + 32);
+            const int8x16_t qy3 = vld1q_s8(y[i].qs + j*4 + 48);
+            const int8x16_t qy4 = vld1q_s8(y[i].qs + j*4 + 64);
+            const int8x16_t qy5 = vld1q_s8(y[i].qs + j*4 + 80);
+            const int8x16_t qy6 = vld1q_s8(y[i].qs + j*4 + 96);
+            const int8x16_t qy7 = vld1q_s8(y[i].qs + j*4 + 112);
+
+            sumi0 = vdotq_s32(sumi0, sqx0, qy0);
+            sumi1 = vdotq_s32(sumi1, sqx1, qy1);
+            sumi0 = vdotq_s32(sumi0, sqx2, qy2);
+            sumi1 = vdotq_s32(sumi1, sqx3, qy3);
+            sumi0 = vdotq_s32(sumi0, sqx4, qy4);
+            sumi1 = vdotq_s32(sumi1, sqx5, qy5);
+            sumi0 = vdotq_s32(sumi0, sqx6, qy6);
+            sumi1 = vdotq_s32(sumi1, sqx7, qy7);
+        }
+
+        const int16x8_t ysum0 = vld1q_s16(y[i].bsums);
+        const int16x8_t ysum1 = vld1q_s16(y[i].bsums + 8);
+
+        sumi0 = vaddq_s32(sumi0, sumi1);
+        sumi0 = vsubq_s32(sumi0, vpaddlq_s16(vaddq_s16(ysum0, ysum1)));
+
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+
+        sumf += d * (float) vaddvq_s32(sumi0);
+    }
+
+    *s = sumf;
+
+#elif defined __ARM_NEON
     float sumf = 0.0f;
 
     const uint8x16_t m3 = vdupq_n_u8(3);
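For reference, the new `__ARM_FEATURE_DOTPROD` path computes, per block, the integer dot product of the packed 2-bit TQ2_0 values (stored as 0..2, i.e. the ternary values -1..1 biased by +1) with the Q8_K activations, then subtracts the block's `bsums` once to undo that bias before scaling by `x[i].d * y[i].d`. Below is a minimal scalar sketch of the same per-block computation, assuming the layout implied by the diff (64 packed bytes per TQ2_0 block, 256 int8 activations and 16 partial sums per Q8_K block); the helper name and plain-array interface are illustrative, not part of ggml.

#include <stdint.h>

// Hypothetical helper, equivalent in result to one iteration of the
// NEON loop above (one TQ2_0 block against one Q8_K block).
static float tq2_0_q8_K_block_dot_scalar(const uint8_t * qs,     // 64 packed bytes, 4 x 2-bit values each
                                          const int8_t  * q8,     // 256 int8 activations
                                          const int16_t * bsums,  // 16 sums of 16 consecutive q8 values
                                          float dx, float dy) {   // block scales: x[i].d (f16->f32) and y[i].d
    int32_t sumi = 0;
    // Two 32-byte groups per block; bit-pair n of byte b in group g
    // belongs to activation g*128 + n*32 + b, matching the qy loads above.
    for (int g = 0; g < 2; ++g) {
        for (int n = 0; n < 4; ++n) {
            for (int b = 0; b < 32; ++b) {
                const int32_t q = (qs[g*32 + b] >> (2*n)) & 3; // 0, 1 or 2
                sumi += q * q8[g*128 + n*32 + b];
            }
        }
    }
    // The stored values are biased by +1 (0..2 encodes -1..1), so subtract
    // the sum of all activations once; bsums holds it as 16 partial sums.
    for (int k = 0; k < 16; ++k) {
        sumi -= bsums[k];
    }
    return dx * dy * (float) sumi;
}

The vector version folds the bias correction into registers (widening-adding the combined bsums with vpaddlq_s16 and subtracting once per block) and keeps two accumulators, sumi0 and sumi1, so the vdotq_s32 chains can overlap and expose more instruction-level parallelism.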