Skip to content

Commit 184215e

Browse files
authored
ggml : fix UB in IQ2_S and IQ3_S (#6012)
1 parent 48358b2 commit 184215e

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

ggml-quants.c

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9025,7 +9025,7 @@ void ggml_vec_dot_iq2_s_q8_K(int n, float * restrict s, size_t bs, const void *
90259025
vld1_s8((const int8_t *)(iq2s_grid + (qs[7] | ((qh[ib32+1] << 2) & 0x300)))));
90269026
qs += 8;
90279027

9028-
vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | (signs[1] << 16)));
9028+
vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | ((uint32_t) signs[1] << 16)));
90299029
vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
90309030
vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
90319031
vs.val[0] = vceqq_u8(vs.val[0], mask2);
@@ -9034,7 +9034,7 @@ void ggml_vec_dot_iq2_s_q8_K(int n, float * restrict s, size_t bs, const void *
90349034
q2s.val[0] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[0], m1)), q2s.val[0]);
90359035
q2s.val[1] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[1], m1)), q2s.val[1]);
90369036

9037-
vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[2] | (signs[3] << 16)));
9037+
vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[2] | ((uint32_t) signs[3] << 16)));
90389038
vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
90399039
vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
90409040
vs.val[0] = vceqq_u8(vs.val[0], mask2);
@@ -9105,12 +9105,12 @@ void ggml_vec_dot_iq2_s_q8_K(int n, float * restrict s, size_t bs, const void *
91059105
iq2s_grid[qs[4] | ((qh[ib32+1] << 8) & 0x300)]);
91069106
qs += 8;
91079107

9108-
__m256i aux256 = _mm256_set1_epi32(signs[0] | (signs[1] << 16));
9108+
__m256i aux256 = _mm256_set1_epi32(signs[0] | ((uint32_t) signs[1] << 16));
91099109
aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2);
91109110
const __m256i s2_1 = _mm256_cmpeq_epi8(aux256, mask2);
91119111
const __m256i q8s_1 = _mm256_sub_epi8(_mm256_xor_si256(s2_1, q8_1), s2_1);
91129112

9113-
aux256 = _mm256_set1_epi32(signs[2] | (signs[3] << 16));
9113+
aux256 = _mm256_set1_epi32(signs[2] | ((uint32_t) signs[3] << 16));
91149114
aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2);
91159115
const __m256i s2_2 = _mm256_cmpeq_epi8(aux256, mask2);
91169116
const __m256i q8s_2 = _mm256_sub_epi8(_mm256_xor_si256(s2_2, q8_2), s2_2);
@@ -9386,7 +9386,7 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * restrict s, size_t bs, const void *
93869386
iq3s_grid[idx.index[6]], iq3s_grid[idx.index[7]]);
93879387

93889388

9389-
vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | (signs[1] << 16)));
9389+
vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | ((uint32_t) signs[1] << 16)));
93909390
vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
93919391
vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
93929392
vs.val[0] = vorrq_u8(vceqq_u8(vs.val[0], mask2), m1);
@@ -9395,7 +9395,7 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * restrict s, size_t bs, const void *
93959395
q3s.val[0] = vmulq_s8(vreinterpretq_s8_u8(vs.val[0]), vreinterpretq_s8_u32(aux32x4_0));
93969396
q3s.val[1] = vmulq_s8(vreinterpretq_s8_u8(vs.val[1]), vreinterpretq_s8_u32(aux32x4_1));
93979397

9398-
vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[2] | (signs[3] << 16)));
9398+
vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[2] | ((uint32_t) signs[3] << 16)));
93999399
vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
94009400
vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
94019401
vs.val[0] = vorrq_u8(vceqq_u8(vs.val[0], mask2), m1);

0 commit comments

Comments
 (0)