
Commit a2c42f9

ikawrakow and Iwan Kawrakow authored
Faster IQ3_KT and IQ4_KT (#453)
* Somewhat faster iq3_kt (AVX2)
* Cleanup
* Slightly faster iq4_kt
* Slightly faster iq4_kt: PP is now almost 50% better than original, TG is ~20% better
* Cleanup
* Very slightly faster iq4_kt TG

Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent 9fb82af commit a2c42f9

1 file changed: +83, -93 lines


ggml/src/iqk/iqk_gemm_ktquants.cpp

Lines changed: 83 additions & 93 deletions
@@ -89,7 +89,8 @@ struct Trellis2 {
     const __m256i mask2 = _mm256_set1_epi32(km32);
 
     inline __m256i next8(uint32_t val1, uint32_t val2) {
-        __m256i mval = _mm256_setr_epi32(val1, val1, val1, val1, val2, val2, val2, val2);
+        __m256i mval = MM256_SET_M128I(_mm_set1_epi32(val2), _mm_set1_epi32(val1));
+        //__m256i mval = _mm256_setr_epi32(val1, val1, val1, val1, val2, val2, val2, val2);
         __m256i mres = _mm256_add_epi32(_mm256_mullo_epi32(mval, mka), mkb);
         return _mm256_xor_si256(_mm256_and_si256(mres, _mm256_set1_epi32(kmask)), _mm256_set1_epi32(km32));
     }
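The next8 change above swaps the eight-lane _mm256_setr_epi32 (which can compile into a chain of scalar-to-vector inserts) for two 128-bit broadcasts glued together with MM256_SET_M128I. Below is a minimal standalone sketch of the equivalence, not part of the commit; it assumes the usual ggml definition of that macro, which is not shown in this diff. Build with -mavx2.

// Standalone sketch: the broadcast-based construction of mval matches the old
// _mm256_setr_epi32 call lane for lane.
#include <immintrin.h>
#include <cstdint>
#include <cstdio>

// Assumed definition; the real one lives in the ggml headers.
#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)

int main() {
    uint32_t val1 = 0x12345678u, val2 = 0x0abcdef0u;
    // Old: eight scalar lane inserts.
    __m256i a = _mm256_setr_epi32(val1, val1, val1, val1, val2, val2, val2, val2);
    // New: broadcast val1 into the low 128 bits, val2 into the high 128 bits.
    __m256i b = MM256_SET_M128I(_mm_set1_epi32(val2), _mm_set1_epi32(val1));
    __m256i eq = _mm256_cmpeq_epi32(a, b);
    printf("identical: %s\n", _mm256_movemask_ps(_mm256_castsi256_ps(eq)) == 0xff ? "yes" : "no");
    return 0;
}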
@@ -163,35 +164,21 @@ static inline __m256 abs_ps(__m256 vals) {
     return _mm256_andnot_ps(sign_bit, vals);
 }
 
-// Negates 32-bit float lanes of an 8x32-bit vector
-// based on 8x8-bit condition var. For float lane i, if byte i of
-// `condition` is nonzero, the float will be negated.
-static inline __m256 conditional_negate_ps(__m256 vals, uint64_t condition_mask_u64) {
-    __m128i condition_bytes = _mm_set_epi64x(0, condition_mask_u64);
-    // Make `should_negate_byte_mask` where byte i == 0xFF if byte i in condition_bytes is zero,
-    // else 0x00 (upper bytes are meaningless)
-    __m128i zeros = _mm_setzero_si128();
-    __m128i is_zero_byte_mask = _mm_cmpeq_epi8(condition_bytes, zeros);
-    __m128i should_negate_byte_mask = _mm_cmpeq_epi8(is_zero_byte_mask, zeros);
-    // Widen lower 8x8 bits of `should_negate_byte_mask` to 8x32 bits by padding zeros
-    // expanded_mask_epi32[j] will be 0x000000FF if vals[j] should be negated, zero otherwise
-    __m256i expanded_mask_epi32 = _mm256_cvtepu8_epi32(should_negate_byte_mask);
-    // Same as above but with all 32 bits of lane j set if vals[j] should be negated (use to make XOR mask)
-    __m256i full_dword_negate_mask = _mm256_cmpgt_epi32(expanded_mask_epi32, _mm256_setzero_si256());
-    // Negate via XOR on sign bits of each 32-bit float
-    __m256i sign_bit_pattern = _mm256_set1_epi32(0x80000000); // MSB set for a 32-bit value
-    __m256i xor_mask_epi32 = _mm256_and_si256(full_dword_negate_mask, sign_bit_pattern);
-    __m256 xor_mask_ps = _mm256_castsi256_ps(xor_mask_epi32);
-    return _mm256_xor_ps(vals, xor_mask_ps);
-}
-
 template <int nrc_y>
 static void mul_mat_iq3_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
     assert(n%QK_K == 0);
     const int nb = n/QK_K;
 
     Trellis1 trellis;
 
+    union { __m256 vec; float val[8]; } s_helper;
+
+    auto shifts = _mm_set_epi32(0, 0, 4, 0);
+
+    __m256i all_signs[4];
+    auto mask1 = _mm256_set1_epi32(0x01);
+    auto mask2 = _mm256_set1_epi32(0x10);
+
     __m256 accd[nrc_y];
     const float * y[nrc_y];
     for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const float *)info.src1_row(iy);
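With the rewrite in the next hunk, the sign is no longer applied to the activation vector, so the conditional_negate_ps helper removed above becomes dead code: the new iq3_kt loop XORs the IEEE-754 sign bit directly into the dequantized trellis magnitude before the fused multiply-add. A scalar sketch (my illustration, not code from the commit) of why the two orderings give the same dot-product contribution:

// (-y)*|x| == y*(-|x|): negating the weight magnitude via its sign bit is
// equivalent to conditionally negating the activation, which is what the old
// conditional_negate_ps path did.
#include <cstdint>
#include <cstring>
#include <cstdio>

static float apply_sign(float magnitude, bool negate) {
    uint32_t bits;
    std::memcpy(&bits, &magnitude, sizeof(bits));
    bits ^= negate ? 0x80000000u : 0u;   // flip only the sign bit
    float out;
    std::memcpy(&out, &bits, sizeof(out));
    return out;
}

int main() {
    float y = 0.25f, x_mag = 3.0f;
    bool negate = true;
    float old_way = (negate ? -y : y) * x_mag;      // negate the activation
    float new_way = y * apply_sign(x_mag, negate);  // negate the weight instead
    printf("%g %g\n", old_way, new_way);            // prints the same value twice
    return 0;
}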
@@ -206,31 +193,28 @@ static void mul_mat_iq3_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
         for (int i = 0; i < nb; ++i) {
             const uint16_t * ql = (const uint16_t *)x[i].ql;
             const uint8_t * qh = x[i].qh;
-            for (int j = 0; j < 128; j+=8) {
-                uint64_t mask1 = 0x0101010101010101 << (j/32);
-                uint64_t mask2 = mask1 << 4;
-                uint32_t val1 = ql[j/8] + 4096;
-                uint32_t val2 = ql[j/8+16] + 4096;
-                const uint64_t signs = *((const uint64_t *)(qh + (j%32)));
-                const float x_scale1 = (x[i].scales[j/32] & 0xf);
-                const float x_scale2 = (x[i].scales[j/32] >> 4);
-                const __m256 x_val1 = abs_ps(trellis_gen8(trellis.next8(val1)));
-                const __m256 x_val2 = abs_ps(trellis_gen8(trellis.next8(val2)));
-                for (int iy = 0; iy < nrc_y; ++iy) {
-                    accd[iy] = _mm256_fmadd_ps(
-                        conditional_negate_ps(
-                            _mm256_load_ps(y[iy] + i*QK_K+j), signs & mask1
-                        ),
-                        _mm256_mul_ps(_mm256_set1_ps(x_scale1), x_val1),
-                        accd[iy]
-                    );
-                    accd[iy] = _mm256_fmadd_ps(
-                        conditional_negate_ps(
-                            _mm256_load_ps(y[iy] + i*QK_K+j+128), signs & mask2
-                        ),
-                        _mm256_mul_ps(_mm256_set1_ps(x_scale2), x_val2),
-                        accd[iy]
-                    );
+            auto s8 = _mm_set1_epi32(*(const uint32_t *)x[i].scales);
+            s8 = _mm_and_si128(_mm_srlv_epi32(s8, shifts), _mm_set1_epi8(0xf));
+            auto s32 = _mm256_cvtepi8_epi32(s8);
+            s_helper.vec = _mm256_cvtepi32_ps(s32);
+            for (int j = 0; j < 4; ++j) all_signs[j] = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)(qh + 8*j)));
+            for (int ib = 0; ib < 4; ++ib) {
+                auto scale1 = _mm256_set1_ps(s_helper.val[ib+0]);
+                auto scale2 = _mm256_set1_ps(s_helper.val[ib+4]);
+                for (int j = 0; j < 4; ++j) {
+                    uint32_t val1 = ql[4*ib+j   ] + 4096;
+                    uint32_t val2 = ql[4*ib+j+16] + 4096;
+                    auto sign1 = _mm256_and_si256(_mm256_cmpeq_epi32(_mm256_and_si256(all_signs[j], mask1), mask1), _mm256_set1_epi32(0x80000000));
+                    auto sign2 = _mm256_and_si256(_mm256_cmpeq_epi32(_mm256_and_si256(all_signs[j], mask2), mask2), _mm256_set1_epi32(0x80000000));
+                    all_signs[j] = _mm256_srli_epi32(all_signs[j], 1);
+                    auto x_val1 = abs_ps(trellis_gen8(trellis.next8(val1)));
+                    auto x_val2 = abs_ps(trellis_gen8(trellis.next8(val2)));
+                    x_val1 = _mm256_mul_ps(scale1, _mm256_xor_ps(x_val1, _mm256_castsi256_ps(sign1)));
+                    x_val2 = _mm256_mul_ps(scale2, _mm256_xor_ps(x_val2, _mm256_castsi256_ps(sign2)));
+                    for (int iy = 0; iy < nrc_y; ++iy) {
+                        accd[iy] = _mm256_fmadd_ps(_mm256_load_ps(y[iy] + i*QK_K+32*ib+8*j    ), x_val1, accd[iy]);
+                        accd[iy] = _mm256_fmadd_ps(_mm256_load_ps(y[iy] + i*QK_K+32*ib+8*j+128), x_val2, accd[iy]);
+                    }
                 }
             }
         }
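In the new loop above, the eight 4-bit block scales live in the four bytes of x[i].scales; the srlv/and/cvtepi8_epi32 sequence spreads them into eight float lanes, low nibbles in lanes 0..3 and high nibbles in lanes 4..7, so s_helper.val[ib] and s_helper.val[ib+4] pick the two scales of block ib. Below is a standalone sketch of just that unpacking step, with made-up scale bytes (my addition, build with -mavx2):

// Unpack 8 packed 4-bit scales into 8 float lanes the same way the kernel does.
#include <immintrin.h>
#include <cstdint>
#include <cstring>
#include <cstdio>

int main() {
    const uint8_t scales[4] = {0x21, 0x43, 0x65, 0x87};     // example packed scales
    uint32_t packed;
    std::memcpy(&packed, scales, 4);

    __m128i s8 = _mm_set1_epi32(packed);                     // broadcast the 4 bytes
    __m128i shifts = _mm_set_epi32(0, 0, 4, 0);              // shift dword 1 right by 4
    s8 = _mm_and_si128(_mm_srlv_epi32(s8, shifts), _mm_set1_epi8(0xf)); // keep nibbles
    __m256i s32 = _mm256_cvtepi8_epi32(s8);                  // widen low 8 bytes to dwords
    __m256 s = _mm256_cvtepi32_ps(s32);

    float out[8];
    _mm256_storeu_ps(out, s);
    for (int k = 0; k < 8; ++k) printf("%g ", out[k]);       // prints: 1 3 5 7 2 4 6 8
    printf("\n");
    return 0;
}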
@@ -250,66 +234,72 @@ static void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
 
     Trellis2 trellis;
 
-    __m256 accd[nrc_y];
-    __m256 accd2[nrc_y];
+    union { __m256 vec; float val[8]; } s_helper;
+    union { __m256i vec; uint32_t val[8]; } o_helper;
+
+    constexpr int k_acc = nrc_y == 1 ? 2 : nrc_y;
+
+    __m256 accd[k_acc];
     const float * y[nrc_y];
-    for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const float *)info.src1_row(iy);
+    float row_sum[nrc_y];
+    for (int iy = 0; iy < nrc_y; ++iy) {
+        y[iy] = (const float *)info.src1_row(iy);
+        auto sum = _mm256_setzero_ps();
+        for (int i = 0; i < n/8; ++i) sum = _mm256_add_ps(sum, _mm256_loadu_ps(y[iy] + 8*i));
+        row_sum[iy] = hsum_float_8(sum);
+    }
 
     for (int ix = 0; ix < nrc_x; ++ix) {
         const float * dptr = (const float *)((const char*)vx + ix*bx);
-        const float d = dptr[0] * 31.75f * 1.01f;
-        const float row_av = dptr[1];
+        auto d = _mm256_set1_ps(dptr[0] * 31.75f * 1.01f);
+        auto dav = dptr[1];
         const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2);
 
-        for (int iy = 0; iy < nrc_y; ++iy) {
-            accd[iy] = _mm256_setzero_ps();
-            accd2[iy] = _mm256_setzero_ps();
-        }
+        for (int iy = 0; iy < k_acc; ++iy) accd[iy] = _mm256_setzero_ps();
 
         for (int i = 0; i < nb; ++i) {
+            auto vshb = _mm256_loadu_si256((const __m256i *)x[i].qs);
            const uint32_t * shb = x[i].qs;
            const uint8_t * ql = (const uint8_t *)(shb + 8);
            const uint8_t * qh = ql + kNumGroups;
-            for (int j = 0; j < 128; j+=8) {
-                const uint32_t offset1 = 4096 + ((shb[j/32+0] & 1) << 15);
-                const uint32_t offset2 = 4096 + ((shb[j/32+4] & 1) << 15);
-                const float x_scale1 = (int)((shb[j/32+0] & 0xff) >> 1) - 64;
-                const float x_scale2 = (int)((shb[j/32+4] & 0xff) >> 1) - 64;
-                const uint32_t sh1 = shb[j/32+0] >> (8 + 6*((j/8)%4));
-                const uint32_t sh2 = shb[j/32+4] >> (8 + 6*((j/8)%4));
-                uint32_t val1 = ql[j/4+ 0] + ((qh[j/4+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + offset1;
-                uint32_t val2 = ql[j/4+32] + ((qh[j/4+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + offset2;
-                uint32_t val3 = ql[j/4+ 1] + ((qh[j/4+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + offset1;
-                uint32_t val4 = ql[j/4+33] + ((qh[j/4+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + offset2;
-                const __m256 x_val1 = trellis_gen8(trellis.next8(val1, val3));
-                const __m256 x_val2 = trellis_gen8(trellis.next8(val2, val4));
-                for (int iy = 0; iy < nrc_y; ++iy) {
-                    accd[iy] = _mm256_fmadd_ps(
-                        _mm256_load_ps(y[iy] + i*QK_K+j),
-                        _mm256_mul_ps(_mm256_set1_ps(x_scale1), x_val1),
-                        accd[iy]
-                    );
-                    accd[iy] = _mm256_fmadd_ps(
-                        _mm256_load_ps(y[iy] + i*QK_K+j+128),
-                        _mm256_mul_ps(_mm256_set1_ps(x_scale2), x_val2),
-                        accd[iy]
-                    );
-                    accd2[iy] = _mm256_add_ps(
-                        _mm256_load_ps(y[iy] + i*QK_K+j),
-                        accd2[iy]
-                    );
-                    accd2[iy] = _mm256_add_ps(
-                        _mm256_load_ps(y[iy] + i*QK_K+j+128),
-                        accd2[iy]
-                    );
+            auto iscales = _mm256_srli_epi32(_mm256_and_si256(vshb, _mm256_set1_epi32(0xff)), 1);
+            s_helper.vec = _mm256_mul_ps(d, _mm256_cvtepi32_ps(_mm256_sub_epi32(iscales, _mm256_set1_epi32(64))));
+            o_helper.vec = _mm256_add_epi32(_mm256_slli_epi32(_mm256_and_si256(vshb, _mm256_set1_epi32(1)), 15), _mm256_set1_epi32(4096));
+            for (int ib = 0; ib < 4; ++ib) {
+                auto scale1 = _mm256_set1_ps(s_helper.val[ib+0]);
+                auto scale2 = _mm256_set1_ps(s_helper.val[ib+4]);
+                for (int j = 0; j < 4; ++j) {
+                    const uint32_t sh1 = shb[ib+0] >> (8 + 6*j);
+                    const uint32_t sh2 = shb[ib+4] >> (8 + 6*j);
+                    uint32_t val1 = ql[8*ib+2*j+ 0] + ((qh[8*ib+2*j+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + o_helper.val[ib+0];
+                    uint32_t val2 = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + o_helper.val[ib+4];
+                    uint32_t val3 = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + o_helper.val[ib+0];
+                    uint32_t val4 = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + o_helper.val[ib+4];
+                    auto x_val1 = _mm256_mul_ps(scale1, trellis_gen8(trellis.next8(val1, val3)));
+                    auto x_val2 = _mm256_mul_ps(scale2, trellis_gen8(trellis.next8(val2, val4)));
+                    if constexpr (nrc_y == 1) {
+                        auto y1 = _mm256_load_ps(y[0] + i*QK_K+32*ib+8*j+  0);
+                        auto y2 = _mm256_load_ps(y[0] + i*QK_K+32*ib+8*j+128);
+                        accd[0] = _mm256_fmadd_ps(y1, x_val1, accd[0]);
+                        accd[1] = _mm256_fmadd_ps(y2, x_val2, accd[1]);
+                    } else {
+                        for (int iy = 0; iy < nrc_y; ++iy) {
+                            auto y1 = _mm256_load_ps(y[iy] + i*QK_K+32*ib+8*j+  0);
+                            auto y2 = _mm256_load_ps(y[iy] + i*QK_K+32*ib+8*j+128);
+                            accd[iy] = _mm256_fmadd_ps(y1, x_val1, accd[iy]);
+                            accd[iy] = _mm256_fmadd_ps(y2, x_val2, accd[iy]);
+                        }
+                    }
                 }
             }
         }
 
-        for (int iy = 0; iy < nrc_y; ++iy) {
-            __m256 res = _mm256_mul_ps(_mm256_set1_ps(d), accd[iy]);
-            __m256 res2 = _mm256_mul_ps(_mm256_set1_ps(row_av), accd2[iy]);
-            info.store(ix, iy, hsum_float_8(res) + hsum_float_8(res2));
+        if constexpr (nrc_y == 1) {
+            info.store(ix, 0, hsum_float_8(_mm256_add_ps(accd[0], accd[1])) + dav*row_sum[0]);
+        } else {
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                info.store(ix, iy, hsum_float_8(accd[iy]) + dav*row_sum[iy]);
+            }
         }
     }
 }
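The iq4_kt rewrite above also moves the row-average term out of the hot loop: each dequantized weight is d*scale*trellis + row_av, so the row_av part of the dot product is just row_av times the sum of the activations, which does not depend on the weight row. The old code accumulated that sum per tile in accd2 for every row of x; the new code computes row_sum[iy] once per activation column and adds dav*row_sum[iy] at the store, dropping two vector adds from the inner loop. A scalar sketch of the algebra, with made-up numbers (my illustration, not code from the commit):

// Factoring the row average out of the dot product gives the same result.
#include <cstdio>

int main() {
    const int n = 8;
    float y[n] = {1, -2, 0.5f, 3, -1, 2, 0.25f, -0.75f};   // activations
    float x[n] = {4, 1, -3, 2, 0, -1, 5, 2};               // scaled trellis values
    float d = 0.125f, row_av = 0.4f;

    // Apply the row average inside the loop (what accumulating sum(y) per tile amounts to).
    float dot_old = 0.0f;
    for (int j = 0; j < n; ++j) dot_old += (d*x[j] + row_av) * y[j];

    // Factor the row average out: row_sum is independent of the weight row.
    float dot_x = 0.0f, row_sum = 0.0f;
    for (int j = 0; j < n; ++j) { dot_x += x[j]*y[j]; row_sum += y[j]; }
    float dot_new = d*dot_x + row_av*row_sum;

    printf("%g %g\n", dot_old, dot_new);   // same value up to rounding
    return 0;
}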
