Skip to content

Commit 53c8434

Browse files
authored
Improve AVX2 for vec_dot_q4_3_q8_0 (#1138)
1 parent c6524f4 commit 53c8434

File tree

1 file changed

+5
-9
lines changed

1 file changed

+5
-9
lines changed

ggml.c

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2947,16 +2947,16 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void *
29472947
#elif defined(__AVX2__)
29482948
// Initialize accumulator with zeros
29492949
__m256 acc = _mm256_setzero_ps();
2950+
float summs = 0.0f;
29502951

29512952
// Main loop
29522953
for (int i = 0; i < nb; i++) {
29532954
const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d));
29542955
const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d));
29552956
const __m256 dx = _mm256_set_m128(d1, d0);
29562957

2957-
const __m128 m0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].m));
2958-
const __m128 m1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].m));
2959-
const __m256 mx = _mm256_set_m128(m1, m0);
2958+
summs += GGML_FP16_TO_FP32(x[2*i + 0].m) * y[i].s0
2959+
+ GGML_FP16_TO_FP32(x[2*i + 1].m) * y[i].s1;
29602960

29612961
const __m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs);
29622962
const __m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs);
@@ -2965,16 +2965,12 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void *
29652965
const __m256 dy = _mm256_broadcast_ss(&y[i].d);
29662966
const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
29672967

2968-
const __m256i syi = _mm256_maddubs_epi16(_mm256_set1_epi8(1), by);
2969-
const __m256 syf = sum_i16_pairs_float(syi);
2970-
29712968
const __m256 q = mul_sum_i8_pairs_float(bx, by);
29722969

2973-
const __m256 sxy = _mm256_fmadd_ps(q, dx, _mm256_mul_ps(mx, syf));
2974-
acc = _mm256_fmadd_ps(sxy, dy, acc);
2970+
acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc);
29752971
}
29762972

2977-
*s = hsum_float_8(acc);
2973+
*s = hsum_float_8(acc) + summs;
29782974
#else
29792975
// scalar
29802976
float sumf = 0.0;

0 commit comments

Comments
 (0)