Skip to content

Commit 5c3d0f1

Browse files
authored
ggml : IQ4_NL sgemm + Q4_0 AVX optimization (#9422)
* squashed: re-add my IQ4_NL sgemm PR #8049; have ggml_vec_dot_q4_0 do two blocks per loop for AVX; try out F16C ggml_vec_dot_iq4_nl, but it's not really faster — as per #8549 we can calculate several blocks at a time with no issue * shuffle * remove F16C iq4_nl as I can't make it faster than before
1 parent 0aadac1 commit 5c3d0f1

File tree

2 files changed

+71
-36
lines changed

2 files changed

+71
-36
lines changed

ggml/src/ggml-quants.c

Lines changed: 33 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,12 @@ static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
230230

231231
return _mm_packus_epi16( bytes1, bytes2);
232232
}
233+
234+
// Multiply the 16 signed int8 lanes of x by those of y and horizontally
// add adjacent products into 8 signed int16 lanes.
// SSE has no signed*signed byte multiply-add, so move x's sign onto y
// (leaving |x| unsigned) and use the unsigned*signed _mm_maddubs_epi16.
static inline __m128i mul_add_epi8_sse(const __m128i x, const __m128i y) {
    const __m128i abs_x    = _mm_sign_epi8(x, x); // |x|, lanewise
    const __m128i signed_y = _mm_sign_epi8(y, x); // y * sign(x), lanewise
    return _mm_maddubs_epi16(abs_x, signed_y);
}
233239
#endif
234240
#elif defined(__SSSE3__)
235241
// horizontally add 4x4 floats
@@ -4206,37 +4212,37 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
42064212

42074213
sumf = hsum_float_8(acc);
42084214
#elif defined(__AVX__)
4209-
// Initialize accumulator with zeros
4210-
__m256 acc = _mm256_setzero_ps();
4211-
4212-
// Main loop
4213-
for (; ib < nb; ++ib) {
4214-
// Compute combined scale for the block
4215-
const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d) );
4216-
4217-
const __m128i lowMask = _mm_set1_epi8(0xF);
4218-
const __m128i off = _mm_set1_epi8(8);
4219-
4220-
const __m128i tmp = _mm_loadu_si128((const __m128i *)x[ib].qs);
4221-
4222-
__m128i bx_0 = _mm_and_si128(lowMask, tmp);
4223-
__m128i by_0 = _mm_loadu_si128((const __m128i *)y[ib].qs);
4224-
bx_0 = _mm_sub_epi8(bx_0, off);
4225-
const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
4226-
4227-
bx_0 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp, 4));
4228-
by_0 = _mm_loadu_si128((const __m128i *)(y[ib].qs + 16));
4229-
bx_0 = _mm_sub_epi8(bx_0, off);
4230-
const __m128i i32_1 = mul_sum_i8_pairs(bx_0, by_0);
4215+
const __m128i mone = _mm_set1_epi16(1);
42314216

4232-
// Convert int32_t to float
4233-
__m256 p = _mm256_cvtepi32_ps(MM256_SET_M128I(i32_0, i32_1));
4217+
__m256 accum1 = _mm256_setzero_ps();
4218+
__m256 accum2 = _mm256_setzero_ps();
4219+
for (; ib + 1 < nb; ib += 2) {
4220+
const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)x[ib + 0].qs);
4221+
const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs);
4222+
const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs);
4223+
const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs + 1);
4224+
const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs);
4225+
const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs + 1);
42344226

4235-
// Apply the scale, and accumulate
4236-
acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
4227+
const __m128i q4b_1_0 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), q4bits_1), _mm_set1_epi8(8));
4228+
const __m128i q4b_1_1 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(q4bits_1, 4)), _mm_set1_epi8(8));
4229+
const __m128i q4b_2_0 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), q4bits_2), _mm_set1_epi8(8));
4230+
const __m128i q4b_2_1 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(q4bits_2, 4)), _mm_set1_epi8(8));
4231+
const __m128i p16_1_0 = mul_add_epi8_sse(q4b_1_0, q8b_1_0);
4232+
const __m128i p16_1_1 = mul_add_epi8_sse(q4b_1_1, q8b_1_1);
4233+
const __m128i p16_2_0 = mul_add_epi8_sse(q4b_2_0, q8b_2_0);
4234+
const __m128i p16_2_1 = mul_add_epi8_sse(q4b_2_1, q8b_2_1);
4235+
const __m128i p_1_0 = _mm_madd_epi16(p16_1_0, mone);
4236+
const __m128i p_1_1 = _mm_madd_epi16(p16_1_1, mone);
4237+
const __m128i p_2_0 = _mm_madd_epi16(p16_2_0, mone);
4238+
const __m128i p_2_1 = _mm_madd_epi16(p16_2_1, mone);
4239+
accum1 = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[ib + 0].d)*GGML_FP16_TO_FP32(x[ib + 0].d)),
4240+
_mm256_cvtepi32_ps(MM256_SET_M128I(p_1_1, p_1_0))), accum1);
4241+
accum2 = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[ib + 1].d)*GGML_FP16_TO_FP32(x[ib + 1].d)),
4242+
_mm256_cvtepi32_ps(MM256_SET_M128I(p_2_1, p_2_0))), accum2);
42374243
}
42384244

4239-
sumf = hsum_float_8(acc);
4245+
sumf = hsum_float_8(_mm256_add_ps(accum1, accum2));
42404246
#elif defined(__SSSE3__)
42414247
// set constants
42424248
const __m128i lowMask = _mm_set1_epi8(0xF);
@@ -11819,15 +11825,6 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * restrict s, size_t bs, const void *
1181911825
#endif
1182011826
}
1182111827

11822-
11823-
#if defined(__AVX__)
11824-
static inline __m128i mul_add_epi8_sse(const __m128i x, const __m128i y) {
11825-
const __m128i ax = _mm_sign_epi8(x, x);
11826-
const __m128i sy = _mm_sign_epi8(y, x);
11827-
return _mm_maddubs_epi16(ax, sy);
11828-
}
11829-
#endif
11830-
1183111828
#if defined(__AVX2__)
1183211829
static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
1183311830
const __m256i ax = _mm256_sign_epi8(x, x);

ggml/src/llamafile/sgemm.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,14 @@ template <> inline __m512 load(const ggml_fp16_t *p) {
235235
}
236236
#endif // __AVX512F__
237237

238+
////////////////////////////////////////////////////////////////////////////////////////////////////
// CONSTANTS

#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
// IQ4_NL dequantization codebook: maps a 4-bit index (0..15) to its
// non-linear signed 8-bit value.
static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
// Codebook preloaded into a vector register for _mm_shuffle_epi8 lookups.
// NOTE(review): this is non-trivial dynamic initialization of a
// namespace-scope __m128i (runs before main) — confirm static-init
// order is not an issue for users of this translation unit.
static const __m128i iq4nlt = _mm_loadu_si128((const __m128i *) kvalues_iq4nl);
#endif
245+
238246
////////////////////////////////////////////////////////////////////////////////////////////////////
239247
// FLOATING POINT MATRIX MULTIPLICATION
240248

@@ -933,6 +941,20 @@ class tinyBLAS_Q0_AVX {
933941
return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)), _mm_set1_epi8(8));
934942
}
935943

944+
inline __m256i load(const block_iq4_nl *b) {
945+
return MM256_SET_M128I(load1(b), load0(b));
946+
}
947+
948+
inline __m128i load0(const block_iq4_nl *b) {
949+
const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
950+
return _mm_shuffle_epi8(iq4nlt, _mm_and_si128(_mm_set1_epi8(15), x));
951+
}
952+
953+
inline __m128i load1(const block_iq4_nl *b) {
954+
const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
955+
return _mm_shuffle_epi8(iq4nlt, _mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)));
956+
}
957+
936958
inline __m256 updot(__m256i u, __m256i s) {
937959
__m256i res;
938960
#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
@@ -1159,6 +1181,22 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
11591181
#endif
11601182
}
11611183

1184+
case GGML_TYPE_IQ4_NL: {
1185+
if (Btype != GGML_TYPE_Q8_0)
1186+
return false;
1187+
#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
1188+
tinyBLAS_Q0_AVX<block_iq4_nl, block_q8_0, float> tb{
1189+
k, (const block_iq4_nl *)A, lda,
1190+
(const block_q8_0 *)B, ldb,
1191+
(float *)C, ldc,
1192+
ith, nth};
1193+
tb.matmul(m, n);
1194+
return true;
1195+
#else
1196+
return false;
1197+
#endif
1198+
}
1199+
11621200
default:
11631201
return false;
11641202
}

0 commit comments

Comments
 (0)