@@ -230,6 +230,12 @@ static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
     return _mm_packus_epi16( bytes1, bytes2);
 }
+
+static inline __m128i mul_add_epi8_sse(const __m128i x, const __m128i y) {
+    const __m128i ax = _mm_sign_epi8(x, x);
+    const __m128i sy = _mm_sign_epi8(y, x);
+    return _mm_maddubs_epi16(ax, sy);
+}
 #endif
 
 #elif defined(__SSSE3__)
 // horizontally add 4x4 floats
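
The relocated helper uses a standard SSSE3 trick: _mm_maddubs_epi16 multiplies unsigned bytes by signed bytes, so taking |x| via _mm_sign_epi8(x, x) and folding x's sign onto y via _mm_sign_epi8(y, x) produces signed-by-signed byte products, summed in adjacent pairs into int16 lanes. Saturation is not a concern here because the q4 operand stays in [-8, 7]. A minimal standalone check of that equivalence (illustrative only, not part of the commit; build with e.g. gcc -mssse3):

#include <stdio.h>
#include <stdint.h>
#include <tmmintrin.h> // SSSE3: _mm_sign_epi8, _mm_maddubs_epi16

/* Same helper as in the hunk above. */
static inline __m128i mul_add_epi8_sse(const __m128i x, const __m128i y) {
    const __m128i ax = _mm_sign_epi8(x, x); // |x| as unsigned bytes
    const __m128i sy = _mm_sign_epi8(y, x); // y with x's sign folded in
    return _mm_maddubs_epi16(ax, sy);       // pairwise sums -> 8x int16
}

int main(void) {
    int8_t xs[16], ys[16];
    for (int i = 0; i < 16; ++i) {
        xs[i] = (int8_t)(i - 8);      // q4-style values in [-8, 7]
        ys[i] = (int8_t)(5 * i - 40); // arbitrary q8-style values
    }
    int16_t out[8];
    _mm_storeu_si128((__m128i *)out,
        mul_add_epi8_sse(_mm_loadu_si128((const __m128i *)xs),
                         _mm_loadu_si128((const __m128i *)ys)));
    for (int i = 0; i < 8; ++i) {
        const int ref = xs[2*i] * ys[2*i] + xs[2*i+1] * ys[2*i+1];
        printf("pair %d: simd=%d ref=%d\n", i, out[i], ref);
    }
    return 0;
}
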
@@ -4206,37 +4212,37 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
 
     sumf = hsum_float_8(acc);
 #elif defined(__AVX__)
-    // Initialize accumulator with zeros
-    __m256 acc = _mm256_setzero_ps();
-
-    // Main loop
-    for (; ib < nb; ++ib) {
-        // Compute combined scale for the block
-        const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d) );
-
-        const __m128i lowMask = _mm_set1_epi8(0xF);
-        const __m128i off = _mm_set1_epi8(8);
-
-        const __m128i tmp = _mm_loadu_si128((const __m128i *)x[ib].qs);
-
-        __m128i bx_0 = _mm_and_si128(lowMask, tmp);
-        __m128i by_0 = _mm_loadu_si128((const __m128i *)y[ib].qs);
-        bx_0 = _mm_sub_epi8(bx_0, off);
-        const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
-
-        bx_0 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp, 4));
-        by_0 = _mm_loadu_si128((const __m128i *)(y[ib].qs + 16));
-        bx_0 = _mm_sub_epi8(bx_0, off);
-        const __m128i i32_1 = mul_sum_i8_pairs(bx_0, by_0);
+    const __m128i mone = _mm_set1_epi16(1);
 
-        // Convert int32_t to float
-        __m256 p = _mm256_cvtepi32_ps(MM256_SET_M128I(i32_0, i32_1));
+    __m256 accum1 = _mm256_setzero_ps();
+    __m256 accum2 = _mm256_setzero_ps();
+    for (; ib + 1 < nb; ib += 2) {
+        const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)x[ib + 0].qs);
+        const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs);
+        const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs);
+        const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs + 1);
+        const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs);
+        const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs + 1);
 
-        // Apply the scale, and accumulate
-        acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
+        const __m128i q4b_1_0 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), q4bits_1), _mm_set1_epi8(8));
+        const __m128i q4b_1_1 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(q4bits_1, 4)), _mm_set1_epi8(8));
+        const __m128i q4b_2_0 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), q4bits_2), _mm_set1_epi8(8));
+        const __m128i q4b_2_1 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(q4bits_2, 4)), _mm_set1_epi8(8));
+        const __m128i p16_1_0 = mul_add_epi8_sse(q4b_1_0, q8b_1_0);
+        const __m128i p16_1_1 = mul_add_epi8_sse(q4b_1_1, q8b_1_1);
+        const __m128i p16_2_0 = mul_add_epi8_sse(q4b_2_0, q8b_2_0);
+        const __m128i p16_2_1 = mul_add_epi8_sse(q4b_2_1, q8b_2_1);
+        const __m128i p_1_0 = _mm_madd_epi16(p16_1_0, mone);
+        const __m128i p_1_1 = _mm_madd_epi16(p16_1_1, mone);
+        const __m128i p_2_0 = _mm_madd_epi16(p16_2_0, mone);
+        const __m128i p_2_1 = _mm_madd_epi16(p16_2_1, mone);
+        accum1 = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[ib + 0].d)*GGML_FP16_TO_FP32(x[ib + 0].d)),
+                _mm256_cvtepi32_ps(MM256_SET_M128I(p_1_1, p_1_0))), accum1);
+        accum2 = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[ib + 1].d)*GGML_FP16_TO_FP32(x[ib + 1].d)),
+                _mm256_cvtepi32_ps(MM256_SET_M128I(p_2_1, p_2_0))), accum2);
     }
 
-    sumf = hsum_float_8(acc);
+    sumf = hsum_float_8(_mm256_add_ps(accum1, accum2));
 #elif defined(__SSSE3__)
     // set constants
     const __m128i lowMask = _mm_set1_epi8(0xF);
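
As a mental model for the unrolled loop above: each q4_0 block packs 32 weights into 16 bytes of nibbles plus an fp16 scale, and is dotted against a q8_0 block of 32 int8 values with its own scale; low nibbles pair with q8 values 0..15 and high nibbles with 16..31, which is exactly the q4b_*_0 / q4b_*_1 split. The _mm_madd_epi16 against mone widens the int16 pair sums to int32 before the float conversion, and the two accumulators keep the even and odd block chains independent. A scalar sketch of one block's contribution (illustrative, not the committed code):

#include <stdint.h>

/* Scalar model of one q4_0 x q8_0 block in the loop above (sketch only).
 * qs4: 16 bytes of packed nibbles (32 weights); d4/d8: the fp16 scales
 * already converted to float. */
static float dot_q4_0_q8_0_block(const uint8_t qs4[16], float d4,
                                 const int8_t qs8[32], float d8) {
    int isum = 0;
    for (int j = 0; j < 16; ++j) {
        const int v0 = (qs4[j] & 0x0F) - 8; // low nibble  -> [-8, 7]
        const int v1 = (qs4[j] >> 4)   - 8; // high nibble -> [-8, 7]
        isum += v0 * qs8[j] + v1 * qs8[j + 16];
    }
    return d4 * d8 * (float)isum;
}
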
@@ -11819,15 +11825,6 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * restrict s, size_t bs, const void *
 #endif
 }
 
-
-#if defined(__AVX__)
-static inline __m128i mul_add_epi8_sse(const __m128i x, const __m128i y) {
-    const __m128i ax = _mm_sign_epi8(x, x);
-    const __m128i sy = _mm_sign_epi8(y, x);
-    return _mm_maddubs_epi16(ax, sy);
-}
-#endif
-
 #if defined(__AVX2__)
 static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
     const __m256i ax = _mm256_sign_epi8(x, x);
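
This removal is the counterpart of the first hunk: mul_add_epi8_sse is unchanged and merely moves ahead of ggml_vec_dot_q4_0_q8_0, so the new AVX path can call it from its earlier position in the file.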