
Commit c542d5a

Cleaning up
1 parent 66a865b commit c542d5a

File tree

1 file changed: +25 −51 lines changed


ggml.c

Lines changed: 25 additions & 51 deletions
@@ -1310,6 +1310,29 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r
     }
 }
 
+#ifdef __AVX2__
+// Is there no better way of doing this?
+// I guess not, AVX is not very good at horizontal sums.
+// The commented-out solution for a horizontal sum was suggested by @pubby as being slightly
+// faster than the solution below. As I don't have an AVX2 system handy right now to test,
+// I'm keeping the original.
+// TODO: Please try it, and if it does make a difference, uncomment it and remove the implementation below.
+//static inline float horizontal_sum(__m256i a) {
+//    __m256i b = _mm256_castps_si256(_mm256_movehdup_ps(_mm256_castsi256_ps(a)));
+//    __m256i sum = _mm256_add_epi32(a, b);
+//    __m256i hi = _mm256_unpackhi_epi64(sum, sum);
+//    sum = _mm256_add_epi32(sum, hi);
+//    return _mm256_cvtsi256_si32(sum) + _mm256_extract_epi32(sum, 4);
+//}
+static inline float horizontal_sum(__m256i a) {
+    __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extracti128_si256(a, 1));
+    __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
+    __m128i sum64 = _mm_add_epi32(hi64, sum128);
+    __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
+    return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
+}
+#endif
+
 static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) {
     assert(k % QK8_0 == 0);
     const int nb = k / QK8_0;
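
The kept horizontal_sum uses the standard halving reduction: fold the upper 128-bit half onto the lower half, then the upper 64 bits, then the last pair of lanes. A standalone sanity check of that idiom, as a sketch (the hsum_i32_8 name, build line, and test values are mine, not the commit's):

```c
// Standalone check of the lane-folding reduction kept in the commit.
// Build with e.g.: gcc -mavx2 -O2 hsum_check.c
#include <immintrin.h>
#include <assert.h>

static inline int hsum_i32_8(__m256i a) {
    // 8 lanes -> 4: add the upper 128-bit half onto the lower half
    __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extracti128_si256(a, 1));
    // 4 lanes -> 2: add the upper 64 bits onto the lower 64 bits
    __m128i hi64  = _mm_unpackhi_epi64(sum128, sum128);
    __m128i sum64 = _mm_add_epi32(hi64, sum128);
    // 2 lanes -> 1: swap the remaining pair and add
    __m128i hi32  = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
    return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
}

int main(void) {
    __m256i a = _mm256_setr_epi32(1, 2, 3, 4, 5, 6, 7, 8);
    assert(hsum_i32_8(a) == 36);  // 1 + 2 + ... + 8
    return 0;
}
```
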
@@ -1399,14 +1422,8 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
 
 #if defined(__AVX2__)
 
-        // Compute the sum of the quants
-        // There is not better way of doing this???
-        __m256i acc = _mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3));
-        __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(acc), _mm256_extracti128_si256(acc, 1));
-        __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
-        __m128i sum64 = _mm_add_epi32(hi64, sum128);
-        __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
-        y[i].s = d * _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
+        // Compute the sum of the quants and set y[i].s
+        y[i].s = d * horizontal_sum(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)));
 
         // Convert int32 to int16
         i0 = _mm256_packs_epi32( i0, i1 );  // 0, 1, 2, 3,  8, 9, 10, 11,  4, 5, 6, 7, 12, 13, 14, 15
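
In scalar terms, i0..i3 hold the block's 32 rounded quants as int32 lanes, and the new line caches d times their plain sum in y[i].s so the dot-product kernels below can reuse it. A sketch of the equivalent scalar computation, assuming the block layout the diff implies (float d, float s, int8_t qs[QK8_0]); the type and function names are mine:

```c
#include <stdint.h>

#define QK8_0 32

// Sketch of the block layout implied by the diff; not copied from the repo.
typedef struct {
    float  d;            // scale
    float  s;            // d * sum of the quants, precomputed at quantization time
    int8_t qs[QK8_0];    // the quants
} block_q8_0_sketch;

// Scalar equivalent of `y[i].s = d * horizontal_sum(...)` above.
static void set_quant_sum(block_q8_0_sketch *y) {
    int sum = 0;
    for (int j = 0; j < QK8_0; ++j) {
        sum += y->qs[j];
    }
    y->s = y->d * (float)sum;
}
```
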
@@ -2411,7 +2428,6 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
         sum8 += x0->d * y0->s + x1->d * y1->s;
 
         const uint8x16_t m4b = vdupq_n_u8(0xf);
-        //const int8x16_t  s8b = vdupq_n_s8(0x8);
 
         const uint8x16_t v0_0 = vld1q_u8(x0->qs);
         const uint8x16_t v0_1 = vld1q_u8(x1->qs);
@@ -2422,12 +2438,6 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
         const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
         const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
 
-        // sub 8
-        //const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
-        //const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
-        //const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
-        //const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
-
         // load y
         const int8x16_t v1_0l = vld1q_s8(y0->qs);
         const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
@@ -2442,27 +2452,17 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
 
 #if defined(__ARM_FEATURE_DOTPROD)
         // dot product into int32x4_t
-        //const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls), v0_0hs, v1_0hs);
-        //const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls), v0_1hs, v1_1hs);
         const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0ls), v0_0h, v1_0hs);
         const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1ls), v0_1h, v1_1hs);
 
         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
         sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
 #else
-        //const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
-        //const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
-        //const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
-        //const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
         const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0ls));
         const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0ls));
         const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0hs));
         const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0hs));
 
-        //const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
-        //const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
-        //const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
-        //const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
         const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1ls));
         const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1ls));
         const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1hs));
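
Why the "sub 8" vectors could go: a q4_0 weight is stored as an unsigned nibble q with value d_x(q − 8), so expanding the block dot product turns the −8 into a multiple of the plain sum of the y quants, which quantize_row_q8_0 now precomputes as y->s = d_y · Σ q̂. In the notation below (mine, not the source's), q_j are the x nibbles and q̂_j the y quants:

\[
\sum_j d_x (q_j - 8)\, d_y\, \hat q_j
  \;=\; d_x d_y \sum_j q_j \hat q_j \;-\; 8\, d_x \Big( d_y \sum_j \hat q_j \Big)
  \;=\; d_x d_y \sum_j q_j \hat q_j \;-\; 8\, d_x \cdot (y{\to}s).
\]

This matches the scalar line `sum8 += x0->d * y0->s + x1->d * y1->s` in the context above; the factor of 8 is presumably applied where the final result is assembled, outside the lines shown in this hunk.
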
@@ -2644,19 +2644,6 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
         const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
         const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
 
-        // We no longer need this. We have computed the sum of the y quants during quantization,
-        // so we get the same as these via the scalar instruction above (summs += x0->m * y0->s + x1->m * y1->s)
-        //const int16x8_t s0i = vaddq_s16(
-        //        vaddq_s16(vmovl_s8(vget_low_s8(v1_0ls)), vmovl_s8(vget_high_s8(v1_0ls))),
-        //        vaddq_s16(vmovl_s8(vget_low_s8(v1_0hs)), vmovl_s8(vget_high_s8(v1_0hs))));
-
-        //const int16x8_t s1i = vaddq_s16(
-        //        vaddq_s16(vmovl_s8(vget_low_s8(v1_1ls)), vmovl_s8(vget_high_s8(v1_1ls))),
-        //        vaddq_s16(vmovl_s8(vget_low_s8(v1_1hs)), vmovl_s8(vget_high_s8(v1_1hs))));
-
-        //sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s0i), vget_high_s16(s0i))), x0->m*y0->d);
-        //sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s1i), vget_high_s16(s1i))), x1->m*y1->d);
-
 #if defined(__ARM_FEATURE_DOTPROD)
         // dot product into int32x4_t
         const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0ls), v0_0h, v1_0hs);
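
The deleted s0i/s1i block is the same trick applied to q4_1, where a weight decodes as d_x·q + m: the m term of the expanded dot product multiplies the raw sum of the y quants, and m · d_y · Σ q̂ is exactly m · y->s, now accumulated by the scalar summs line quoted in the removed comment. A minimal scalar sketch of the identity; the struct and function names here are hypothetical, chosen for illustration:

```c
#include <stdint.h>

#define QK 32  // quants per block, matching QK8_0 in the diff

// Hypothetical stand-ins for the real block types, just to show the algebra.
typedef struct { float d, m; uint8_t q[QK]; } xblk;  // value_j = d*q[j] + m
typedef struct { float d, s; int8_t  q[QK]; } yblk;  // s = d * sum(q[j])

static float dot_q4_1_q8_0_scalar(const xblk *x, const yblk *y) {
    int sumqq = 0;
    for (int j = 0; j < QK; ++j) {
        sumqq += (int)x->q[j] * (int)y->q[j];   // integer part of the dot product
    }
    // sum_j (d_x*q + m) * (d_y*qy) = d_x*d_y * sum(q*qy) + m * (d_y * sum(qy))
    //                              = d_x*d_y * sumqq     + m * y->s
    return x->d * y->d * (float)sumqq + x->m * y->s;
}
```
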
@@ -2702,11 +2689,9 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
 
         const __m256 d0v = _mm256_broadcast_ss( d0 );
         const __m256 d1v = _mm256_broadcast_ss( d1 );
-        //const __m256 m0v = _mm256_broadcast_ss( m0 );
 
         // Compute combined scales
         const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
-        //const __m256 d1m0 = _mm256_mul_ps( d1v, m0v );
 
         // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
         const __m256i bx = bytes_from_nibbles_32(x[i].qs);
@@ -2728,17 +2713,6 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
 
         // Accumulate d0*d1*x*y
         acc = _mm256_fmadd_ps( d0d1, xy, acc );
-
-        // We no longer need this. We have computed the sum of the y quants during quantization,
-        // so we get the same as these via the single scalar instruction above (summs += x[i].m * y[i].s)
-        //// Compute sum of y values
-        //const __m256i y16_l = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
-        //const __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
-        //const __m256i ysumi = _mm256_madd_epi16( _mm256_add_epi16(y16_l, y16_h), ones );
-        //const __m256 ysum = _mm256_cvtepi32_ps( ysumi );
-
-        //// Accumulate d1*m0*y
-        //acc = _mm256_fmadd_ps( d1m0, ysum, acc );
     }
 
     // Return horizontal sum of the acc vector
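
The context line above refers to reducing the eight float lanes of acc to the scalar result. For completeness, the float analogue of the integer reduction looks roughly like this (a sketch; the actual code at that point lies outside the shown hunk, and the name hsum_float_8 is mine):

```c
#include <immintrin.h>

// Horizontal sum of the 8 float lanes of a __m256, by repeated halving.
static inline float hsum_float_8(__m256 x) {
    // 8 lanes -> 4: add the upper 128-bit half onto the lower half
    __m128 r = _mm_add_ps(_mm256_castps256_ps128(x), _mm256_extractf128_ps(x, 1));
    // 4 lanes -> 2: add the high pair onto the low pair
    r = _mm_add_ps(r, _mm_movehl_ps(r, r));
    // 2 lanes -> 1: add lane 1 onto lane 0
    r = _mm_add_ss(r, _mm_movehdup_ps(r));
    return _mm_cvtss_f32(r);
}
```
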

0 commit comments