@@ -6828,6 +6828,7 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
         int bit = 0;
         int is  = 0;
+        __m256i xvbit;
 
         const uint8_t * restrict q3 = x[i].qs;
         const int8_t  * restrict q8 = y[i].qs;
@@ -6836,21 +6837,25 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
             // load low 2 bits
             const __m256i q3bits = __lasx_xvld((const __m256i*)q3, 0); q3 += 32;
 
+            xvbit = __lasx_xvreplgr2vr_h(bit);
             // prepare low and high bits
             const __m256i q3l_0 = __lasx_xvand_v(q3bits, m3);
-            const __m256i q3h_0 = __lasx_xvslli_h(__lasx_xvsrli_h (__lasx_xvandn_v(hbits, __lasx_xvslli_h (mone, bit )), bit ), 2);
+            const __m256i q3h_0 = __lasx_xvslli_h(__lasx_xvsrl_h (__lasx_xvandn_v(hbits, __lasx_xvsll_h (mone, xvbit )), xvbit ), 2);
             ++bit;
 
+            xvbit = __lasx_xvreplgr2vr_h(bit);
             const __m256i q3l_1 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 2), m3);
-            const __m256i q3h_1 = __lasx_xvslli_h(__lasx_xvsrli_h (__lasx_xvandn_v(hbits, __lasx_xvslli_h (mone, bit )), bit ), 2);
+            const __m256i q3h_1 = __lasx_xvslli_h(__lasx_xvsrl_h (__lasx_xvandn_v(hbits, __lasx_xvsll_h (mone, xvbit )), xvbit ), 2);
             ++bit;
 
+            xvbit = __lasx_xvreplgr2vr_h(bit);
             const __m256i q3l_2 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 4), m3);
-            const __m256i q3h_2 = __lasx_xvslli_h(__lasx_xvsrli_h (__lasx_xvandn_v(hbits, __lasx_xvslli_h (mone, bit )), bit ), 2);
+            const __m256i q3h_2 = __lasx_xvslli_h(__lasx_xvsrl_h (__lasx_xvandn_v(hbits, __lasx_xvsll_h (mone, xvbit )), xvbit ), 2);
             ++bit;
 
+            xvbit = __lasx_xvreplgr2vr_h(bit);
             const __m256i q3l_3 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 6), m3);
-            const __m256i q3h_3 = __lasx_xvslli_h(__lasx_xvsrli_h (__lasx_xvandn_v(hbits, __lasx_xvslli_h (mone, bit )), bit ), 2);
+            const __m256i q3h_3 = __lasx_xvslli_h(__lasx_xvsrl_h (__lasx_xvandn_v(hbits, __lasx_xvsll_h (mone, xvbit )), xvbit ), 2);
             ++bit;
 
             // load Q8 quants
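
The q3_K hunks above replace the immediate-count shift intrinsics `__lasx_xvslli_h`/`__lasx_xvsrli_h`, whose shift amount must be a compile-time constant, with the register-count forms `__lasx_xvsll_h`/`__lasx_xvsrl_h`, so the loop-carried scalar `bit` can be used: it is first broadcast into every 16-bit lane with `__lasx_xvreplgr2vr_h`. A minimal sketch of the pattern, assuming a LoongArch64 toolchain with LASX enabled (`-mlasx`) and `<lasxintrin.h>`; the helper name `q3_high_bits` is illustrative, not from the diff:

```c
#include <lasxintrin.h>

// Extract the high-bit contribution for one q3_K sub-block.  `bit` is a
// runtime value, so the immediate-form shifts cannot take it directly.
static inline __m256i q3_high_bits(__m256i hbits, __m256i mone, int bit) {
    const __m256i xvbit = __lasx_xvreplgr2vr_h(bit);  // broadcast shift count to all 16-bit lanes
    // ~hbits & (mone << bit): nonzero where the high bit for this sub-block is clear
    __m256i h = __lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit));
    h = __lasx_xvsrl_h(h, xvbit);                     // move the extracted flag down to bit 0
    return __lasx_xvslli_h(h, 2);                     // immediate form stays: 2 is a literal constant
}
```

Note that the final shift keeps the immediate `_xvslli_h` form because its count is a constant; only the shifts by `bit` need the vector-operand variants.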
@@ -8033,6 +8038,7 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
         __m256i sumi = __lasx_xvldi(0);
 
         int bit = 0;
+        __m256i xvbit;
 
         for (int j = 0; j < QK_K/64; ++j) {
 
@@ -8041,13 +8047,15 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
             const __m256i q5bits = __lasx_xvld((const __m256i*)q5, 0); q5 += 32;
 
+            xvbit = __lasx_xvreplgr2vr_h(bit++);
             const __m256i q5l_0 = __lasx_xvand_v(q5bits, m4);
-            const __m256i q5h_0 = __lasx_xvslli_h(__lasx_xvsrli_h (__lasx_xvand_v(hbits, hmask), bit++ ), 4);
+            const __m256i q5h_0 = __lasx_xvslli_h(__lasx_xvsrl_h (__lasx_xvand_v(hbits, hmask), xvbit ), 4);
             const __m256i q5_0 = __lasx_xvadd_b(q5l_0, q5h_0);
             hmask = __lasx_xvslli_h(hmask, 1);
 
+            xvbit = __lasx_xvreplgr2vr_h(bit++);
             const __m256i q5l_1 = __lasx_xvand_v(__lasx_xvsrli_h(q5bits, 4), m4);
-            const __m256i q5h_1 = __lasx_xvslli_h(__lasx_xvsrli_h (__lasx_xvand_v(hbits, hmask), bit++ ), 4);
+            const __m256i q5h_1 = __lasx_xvslli_h(__lasx_xvsrl_h (__lasx_xvand_v(hbits, hmask), xvbit ), 4);
             const __m256i q5_1 = __lasx_xvadd_b(q5l_1, q5h_1);
             hmask = __lasx_xvslli_h(hmask, 1);
 
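
The q5_K hunks apply the same fix to `(hbits & hmask) >> bit` scaled by 16, again broadcasting the runtime `bit` before shifting; `hmask` walks up one bit position per step via the constant-count `__lasx_xvslli_h(hmask, 1)`. A sketch under the same LASX-toolchain assumptions, with `q5_high_bit` as an illustrative name:

```c
#include <lasxintrin.h>

// Isolate the fifth quant bit selected by `hmask` and place it at value 16.
// `bit` advances together with `hmask`, so it is only known at run time.
static inline __m256i q5_high_bit(__m256i hbits, __m256i hmask, int bit) {
    const __m256i xvbit = __lasx_xvreplgr2vr_h(bit);  // runtime count, broadcast per 16-bit lane
    return __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvand_v(hbits, hmask), xvbit), 4);
}
```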