@@ -6088,6 +6088,7 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r

        const uint8_t * restrict q2 = x[i].qs;
        const int8_t * restrict q8 = y[i].qs;
+
        const __m128i mins_and_scales = __lsx_vld((const __m128i*)x[i].scales, 0);
        const __m128i scales8 = __lsx_vand_v(mins_and_scales, m4);
        const __m128i mins8 = __lsx_vand_v(__lsx_vsrli_h(mins_and_scales, 4), m4);
@@ -6807,6 +6808,8 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
    for (int i = 0; i < nb; ++i) {

        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const uint8_t * restrict q3 = x[i].qs;
+        const int8_t * restrict q8 = y[i].qs;
        // Set up scales
        memcpy(aux, x[i].scales, 12);
        __m128i scales128 = lsx_set_w(
@@ -6830,8 +6833,6 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
        int is = 0;
        __m256i xvbit;

-        const uint8_t * restrict q3 = x[i].qs;
-        const int8_t * restrict q8 = y[i].qs;

        for (int j = 0; j < QK_K/128; ++j) {
            // load low 2 bits
@@ -7404,6 +7405,9 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
    *s = vec_extract(vsumf0, 0);

#elif defined __loongarch_asx
+    GGML_UNUSED(kmask1);
+    GGML_UNUSED(kmask2);
+    GGML_UNUSED(kmask3);

    const __m256i m4 = __lasx_xvreplgr2vr_b(0xF);

@@ -7416,6 +7420,11 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
        const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);

        memcpy(utmp, x[i].scales, 12);
+        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+        const uint32_t uaux = utmp[1] & kmask1;
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[2] = uaux;
+        utmp[0] &= kmask1;

        const uint8_t * restrict q4 = x[i].qs;
        const int8_t * restrict q8 = y[i].qs;
@@ -7455,16 +7464,17 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r

        __m256 vd = __lasx_xvreplfr2vr_s(d);
        acc = __lasx_xvfmadd_s(vd, __lasx_xvffint_s_w(sumi), acc);
+
    }

    acc_m = __lsx_vfadd_s(acc_m, (__m128)__lsx_vpermi_w((__m128i)acc_m, (__m128i)acc_m, 0xee));
    __m128i tmp1 = __lsx_vinsgr2vr_w(__lsx_vldi(0), __lsx_vpickve2gr_w((__m128i)acc_m, 1), 0);
    acc_m = __lsx_vfadd_s(acc_m, (__m128)tmp1);

+
    ft_union fi;
    fi.i = __lsx_vpickve2gr_w(acc_m, 0);
    *s = hsum_float_8(acc) + fi.f ;
-
#else

    const uint8_t * scales = (const uint8_t*)&utmp[0];
@@ -8002,6 +8012,9 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
    *s = vec_extract(vsumf0, 0);

#elif defined __loongarch_asx
+    GGML_UNUSED(kmask1);
+    GGML_UNUSED(kmask2);
+    GGML_UNUSED(kmask3);

    const __m256i m4 = __lasx_xvreplgr2vr_b(0xF);
    const __m128i mzero = __lsx_vldi(0);
@@ -8020,6 +8033,11 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
        const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);

        memcpy(utmp, x[i].scales, 12);
+        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+        const uint32_t uaux = utmp[1] & kmask1;
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[2] = uaux;
+        utmp[0] &= kmask1;

        const __m256i mins_and_scales = lasx_extu8_16(lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]));

@@ -8069,10 +8087,12 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
            p16_1 = lasx_madd_h(scale_1, p16_1);

            sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_1));
+
        }

        __m256 vd = __lasx_xvreplfr2vr_s(d);
        acc = __lasx_xvfmadd_s(vd, __lasx_xvffint_s_w(sumi), acc);
+
    }

    *s = hsum_float_8(acc) + summs;
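
Note (not part of the patch): the utmp[] shuffle added to the __loongarch_asx paths of ggml_vec_dot_q4_K_q8_K and ggml_vec_dot_q5_K_q8_K is the same scale/min unpacking used by the other SIMD backends in this file. Below is a minimal standalone sketch of what it computes, assuming the kmask1/kmask2/kmask3 constants defined earlier in ggml-quants.c (0x3f3f3f3f, 0x0f0f0f0f, 0x03030303); the helper name unpack_k_scales is illustrative, not from the source.

#include <stdint.h>
#include <string.h>

// Illustrative helper (not in the patch): expand the 12 packed bytes of a
// q4_K/q5_K super-block's scales into eight 6-bit scales and eight 6-bit mins,
// one value per byte, using the same bit shuffle as the added diff lines.
static void unpack_k_scales(const uint8_t packed[12], uint8_t scales[8], uint8_t mins[8]) {
    const uint32_t kmask1 = 0x3f3f3f3f;  // low 6 bits of each byte
    const uint32_t kmask2 = 0x0f0f0f0f;  // low 4 bits of each byte
    const uint32_t kmask3 = 0x03030303;  // low 2 bits of each byte

    uint32_t utmp[4];
    memcpy(utmp, packed, 12);

    utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);  // mins 4..7
    const uint32_t uaux = utmp[1] & kmask1;                                  // mins 0..3
    utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);         // scales 4..7
    utmp[2] = uaux;
    utmp[0] &= kmask1;                                                       // scales 0..3

    memcpy(scales, utmp,     8);  // utmp[0..1]: eight scales, one per byte
    memcpy(mins,   utmp + 2, 8);  // utmp[2..3]: eight mins, one per byte
}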