@@ -89,7 +89,8 @@ struct Trellis2 {
     const __m256i mask2 = _mm256_set1_epi32(km32);
 
     inline __m256i next8(uint32_t val1, uint32_t val2) {
-        __m256i mval = _mm256_setr_epi32(val1, val1, val1, val1, val2, val2, val2, val2);
+        __m256i mval = MM256_SET_M128I(_mm_set1_epi32(val2), _mm_set1_epi32(val1));
+        //__m256i mval = _mm256_setr_epi32(val1, val1, val1, val1, val2, val2, val2, val2);
        __m256i mres = _mm256_add_epi32(_mm256_mullo_epi32(mval, mka), mkb);
         return _mm256_xor_si256(_mm256_and_si256(mres, _mm256_set1_epi32(kmask)), _mm256_set1_epi32(km32));
     }
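
Aside (not part of the diff): both constructions of mval give the same lane layout, val1 in lanes 0-3 and val2 in lanes 4-7; the new form simply builds it from two broadcast 128-bit halves. A minimal standalone check, assuming an AVX2-capable compiler; the MM256_SET_M128I definition below is the usual fallback for _mm256_set_m128i and may differ from the project's own macro:

#include <immintrin.h>
#include <cstdint>
#include <cstdio>

// Common fallback when _mm256_set_m128i is unavailable (an assumption, not the project's header).
#define MM256_SET_M128I(hi, lo) _mm256_insertf128_si256(_mm256_castsi128_si256((lo)), (hi), 1)

int main() {
    uint32_t val1 = 0x1111, val2 = 0x2222;
    __m256i a = _mm256_setr_epi32(val1, val1, val1, val1, val2, val2, val2, val2);
    __m256i b = MM256_SET_M128I(_mm_set1_epi32(val2), _mm_set1_epi32(val1));
    // All 32 byte comparisons must match for the mask to be 0xFFFFFFFF.
    uint32_t eq = (uint32_t)_mm256_movemask_epi8(_mm256_cmpeq_epi8(a, b));
    printf("same layout: %s\n", eq == 0xFFFFFFFFu ? "yes" : "no");
    return 0;
}
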
@@ -163,35 +164,21 @@ static inline __m256 abs_ps(__m256 vals) {
     return _mm256_andnot_ps(sign_bit, vals);
 }
 
 
-// Negates 32-bit float lanes of an 8x32-bit vector
-// based on 8x8-bit condition var. For float lane i, if byte i of
-// `condition` is nonzero, the float will be negated.
-static inline __m256 conditional_negate_ps(__m256 vals, uint64_t condition_mask_u64) {
-    __m128i condition_bytes = _mm_set_epi64x(0, condition_mask_u64);
-    // Make `should_negate_byte_mask` where byte i == 0xFF if byte i in condition_bytes is zero,
-    // else 0x00 (upper bytes are meaningless)
-    __m128i zeros = _mm_setzero_si128();
-    __m128i is_zero_byte_mask = _mm_cmpeq_epi8(condition_bytes, zeros);
-    __m128i should_negate_byte_mask = _mm_cmpeq_epi8(is_zero_byte_mask, zeros);
-    // Widen lower 8x8 bits of `should_negate_byte_mask` to 8x32 bits by padding zeros
-    // expanded_mask_epi32[j] will be 0x000000FF if vals[j] should be negated, zero otherwise
-    __m256i expanded_mask_epi32 = _mm256_cvtepu8_epi32(should_negate_byte_mask);
-    // Same as above but with all 32 bits of lane j set if vals[j] should be negated (use to make XOR mask)
-    __m256i full_dword_negate_mask = _mm256_cmpgt_epi32(expanded_mask_epi32, _mm256_setzero_si256());
-    // Negate via XOR on sign bits of each 32-bit float
-    __m256i sign_bit_pattern = _mm256_set1_epi32(0x80000000); // MSB set for a 32-bit value
-    __m256i xor_mask_epi32 = _mm256_and_si256(full_dword_negate_mask, sign_bit_pattern);
-    __m256 xor_mask_ps = _mm256_castsi256_ps(xor_mask_epi32);
-    return _mm256_xor_ps(vals, xor_mask_ps);
-}
-
 template <int nrc_y>
 static void mul_mat_iq3_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
     assert(n%QK_K == 0);
     const int nb = n/QK_K;
 
     Trellis1 trellis;
+    union { __m256 vec; float val[8]; } s_helper;
+
+    auto shifts = _mm_set_epi32(0, 0, 4, 0);
+
+    __m256i all_signs[4];
+    auto mask1 = _mm256_set1_epi32(0x01);
+    auto mask2 = _mm256_set1_epi32(0x10);
+
     __m256 accd[nrc_y];
     const float * y[nrc_y];
     for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const float *)info.src1_row(iy);
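
Aside (not part of the diff): shifts and s_helper are set up so the eight 4-bit block scales can be unpacked in one shot at the top of the i-loop in the next hunk. Byte ib of x[i].scales carries the scale of sub-block ib in its low nibble and of sub-block ib+4 in its high nibble; after the srlv/and/cvtepi8 sequence, helper lane ib holds the low nibble and lane ib+4 the high nibble. A standalone sketch with made-up scale bytes (illustrative data, not from the file):

#include <immintrin.h>
#include <cstdint>
#include <cstdio>

int main() {
    // Low nibbles 1,3,5,7 belong to sub-blocks 0..3; high nibbles 2,4,6,8 to sub-blocks 4..7.
    uint8_t scales[4] = {0x21, 0x43, 0x65, 0x87};

    __m128i shifts = _mm_set_epi32(0, 0, 4, 0);              // only lane 1 is shifted right by 4
    __m128i s8 = _mm_set1_epi32(*(const uint32_t *)scales);  // broadcast the 4 packed bytes
    s8 = _mm_and_si128(_mm_srlv_epi32(s8, shifts), _mm_set1_epi8(0xf));
    __m256 s_ps = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(s8));  // widen the low 8 bytes to 8 float lanes

    float val[8];
    _mm256_storeu_ps(val, s_ps);
    for (int i = 0; i < 8; ++i) printf("%g ", val[i]);  // prints: 1 3 5 7 2 4 6 8
    printf("\n");
    return 0;
}
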
@@ -206,31 +193,28 @@ static void mul_mat_iq3_kt_F32_T(int n, const void * vx, size_t bx, const DataIn
         for (int i = 0; i < nb; ++i) {
             const uint16_t * ql = (const uint16_t *)x[i].ql;
             const uint8_t * qh = x[i].qh;
-            for (int j = 0; j < 128; j+=8) {
-                uint64_t mask1 = 0x0101010101010101 << (j/32);
-                uint64_t mask2 = mask1 << 4;
-                uint32_t val1 = ql[j/8] + 4096;
-                uint32_t val2 = ql[j/8+16] + 4096;
-                const uint64_t signs = *((const uint64_t *)(qh + (j%32)));
-                const float x_scale1 = (x[i].scales[j/32] & 0xf);
-                const float x_scale2 = (x[i].scales[j/32] >> 4);
-                const __m256 x_val1 = abs_ps(trellis_gen8(trellis.next8(val1)));
-                const __m256 x_val2 = abs_ps(trellis_gen8(trellis.next8(val2)));
-                for (int iy = 0; iy < nrc_y; ++iy) {
-                    accd[iy] = _mm256_fmadd_ps(
-                        conditional_negate_ps(
-                            _mm256_load_ps(y[iy] + i*QK_K+j), signs & mask1
-                        ),
-                        _mm256_mul_ps(_mm256_set1_ps(x_scale1), x_val1),
-                        accd[iy]
-                    );
-                    accd[iy] = _mm256_fmadd_ps(
-                        conditional_negate_ps(
-                            _mm256_load_ps(y[iy] + i*QK_K+j+128), signs & mask2
-                        ),
-                        _mm256_mul_ps(_mm256_set1_ps(x_scale2), x_val2),
-                        accd[iy]
-                    );
+            auto s8 = _mm_set1_epi32(*(const uint32_t *)x[i].scales);
+            s8 = _mm_and_si128(_mm_srlv_epi32(s8, shifts), _mm_set1_epi8(0xf));
+            auto s32 = _mm256_cvtepi8_epi32(s8);
+            s_helper.vec = _mm256_cvtepi32_ps(s32);
+            for (int j = 0; j < 4; ++j) all_signs[j] = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)(qh + 8*j)));
+            for (int ib = 0; ib < 4; ++ib) {
+                auto scale1 = _mm256_set1_ps(s_helper.val[ib+0]);
+                auto scale2 = _mm256_set1_ps(s_helper.val[ib+4]);
+                for (int j = 0; j < 4; ++j) {
+                    uint32_t val1 = ql[4*ib+j   ] + 4096;
+                    uint32_t val2 = ql[4*ib+j+16] + 4096;
+                    auto sign1 = _mm256_and_si256(_mm256_cmpeq_epi32(_mm256_and_si256(all_signs[j], mask1), mask1), _mm256_set1_epi32(0x80000000));
+                    auto sign2 = _mm256_and_si256(_mm256_cmpeq_epi32(_mm256_and_si256(all_signs[j], mask2), mask2), _mm256_set1_epi32(0x80000000));
+                    all_signs[j] = _mm256_srli_epi32(all_signs[j], 1);
+                    auto x_val1 = abs_ps(trellis_gen8(trellis.next8(val1)));
+                    auto x_val2 = abs_ps(trellis_gen8(trellis.next8(val2)));
+                    x_val1 = _mm256_mul_ps(scale1, _mm256_xor_ps(x_val1, _mm256_castsi256_ps(sign1)));
+                    x_val2 = _mm256_mul_ps(scale2, _mm256_xor_ps(x_val2, _mm256_castsi256_ps(sign2)));
+                    for (int iy = 0; iy < nrc_y; ++iy) {
+                        accd[iy] = _mm256_fmadd_ps(_mm256_load_ps(y[iy] + i*QK_K+32*ib+8*j    ), x_val1, accd[iy]);
+                        accd[iy] = _mm256_fmadd_ps(_mm256_load_ps(y[iy] + i*QK_K+32*ib+8*j+128), x_val2, accd[iy]);
+                    }
                 }
             }
         }
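
Aside (not part of the diff): the rewrite drops the scalar-mask conditional_negate_ps helper in favour of a purely vector sign test. Each byte of qh is widened once into a 32-bit lane of all_signs; in the inner loop a compare against the current bit (mask1 for the first 128 values of the block, mask2 for the second 128) yields an all-ones lane where the bit is set, which is reduced to just the sign bit and XORed into the floats. A self-contained sketch of that idiom; negate_by_bit0 is an illustrative name, not from the file:

#include <immintrin.h>
#include <cstdint>
#include <cstdio>

// Negate float lane i when bit 0 of byte i of `bits` is set (sketch of the cmpeq/and/xor idiom).
static __m256 negate_by_bit0(__m256 vals, const uint8_t bits[8]) {
    __m256i lanes = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)bits));
    __m256i mask  = _mm256_set1_epi32(0x01);
    __m256i hit   = _mm256_cmpeq_epi32(_mm256_and_si256(lanes, mask), mask);  // all-ones where the bit is set
    __m256i sign  = _mm256_and_si256(hit, _mm256_set1_epi32(0x80000000));     // keep only the sign bit
    return _mm256_xor_ps(vals, _mm256_castsi256_ps(sign));                    // flip the sign of selected lanes
}

int main() {
    const uint8_t bits[8] = {1, 0, 1, 0, 0, 1, 0, 1};
    __m256 v = _mm256_setr_ps(1, 2, 3, 4, 5, 6, 7, 8);
    float out[8];
    _mm256_storeu_ps(out, negate_by_bit0(v, bits));
    for (int i = 0; i < 8; ++i) printf("%g ", out[i]);  // prints: -1 2 -3 4 5 -6 7 -8
    printf("\n");
    return 0;
}
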
@@ -250,66 +234,72 @@ static void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataIn
 
     Trellis2 trellis;
 
-    __m256 accd[nrc_y];
-    __m256 accd2[nrc_y];
+    union { __m256 vec; float val[8]; } s_helper;
+    union { __m256i vec; uint32_t val[8]; } o_helper;
+
+    constexpr int k_acc = nrc_y == 1 ? 2 : nrc_y;
+
+    __m256 accd[k_acc];
     const float * y[nrc_y];
-    for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const float *)info.src1_row(iy);
+    float row_sum[nrc_y];
+    for (int iy = 0; iy < nrc_y; ++iy) {
+        y[iy] = (const float *)info.src1_row(iy);
+        auto sum = _mm256_setzero_ps();
+        for (int i = 0; i < n/8; ++i) sum = _mm256_add_ps(sum, _mm256_loadu_ps(y[iy] + 8*i));
+        row_sum[iy] = hsum_float_8(sum);
+    }
 
     for (int ix = 0; ix < nrc_x; ++ix) {
         const float * dptr = (const float *)((const char *)vx + ix*bx);
-        const float d = dptr[0] * 31.75f * 1.01f;
-        const float row_av = dptr[1];
+        auto d = _mm256_set1_ps(dptr[0] * 31.75f * 1.01f);
+        auto dav = dptr[1];
         const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2);
 
-        for (int iy = 0; iy < nrc_y; ++iy) {
-            accd[iy] = _mm256_setzero_ps();
-            accd2[iy] = _mm256_setzero_ps();
-        }
+        for (int iy = 0; iy < k_acc; ++iy) accd[iy] = _mm256_setzero_ps();
 
         for (int i = 0; i < nb; ++i) {
+            auto vshb = _mm256_loadu_si256((const __m256i *)x[i].qs);
             const uint32_t * shb = x[i].qs;
             const uint8_t * ql = (const uint8_t *)(shb + 8);
             const uint8_t * qh = ql + kNumGroups;
-            for (int j = 0; j < 128; j+=8) {
-                const uint32_t offset1 = 4096 + ((shb[j/32+0] & 1) << 15);
-                const uint32_t offset2 = 4096 + ((shb[j/32+4] & 1) << 15);
-                const float x_scale1 = (int)((shb[j/32+0] & 0xff) >> 1) - 64;
-                const float x_scale2 = (int)((shb[j/32+4] & 0xff) >> 1) - 64;
-                const uint32_t sh1 = shb[j/32+0] >> (8 + 6*((j/8)%4));
-                const uint32_t sh2 = shb[j/32+4] >> (8 + 6*((j/8)%4));
-                uint32_t val1 = ql[j/4+ 0] + ((qh[j/4+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + offset1;
-                uint32_t val2 = ql[j/4+32] + ((qh[j/4+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + offset2;
-                uint32_t val3 = ql[j/4+ 1] + ((qh[j/4+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + offset1;
-                uint32_t val4 = ql[j/4+33] + ((qh[j/4+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + offset2;
-                const __m256 x_val1 = trellis_gen8(trellis.next8(val1, val3));
-                const __m256 x_val2 = trellis_gen8(trellis.next8(val2, val4));
-                for (int iy = 0; iy < nrc_y; ++iy) {
-                    accd[iy] = _mm256_fmadd_ps(
-                        _mm256_load_ps(y[iy] + i*QK_K+j),
-                        _mm256_mul_ps(_mm256_set1_ps(x_scale1), x_val1),
-                        accd[iy]
-                    );
-                    accd[iy] = _mm256_fmadd_ps(
-                        _mm256_load_ps(y[iy] + i*QK_K+j+128),
-                        _mm256_mul_ps(_mm256_set1_ps(x_scale2), x_val2),
-                        accd[iy]
-                    );
-                    accd2[iy] = _mm256_add_ps(
-                        _mm256_load_ps(y[iy] + i*QK_K+j),
-                        accd2[iy]
-                    );
-                    accd2[iy] = _mm256_add_ps(
-                        _mm256_load_ps(y[iy] + i*QK_K+j+128),
-                        accd2[iy]
-                    );
+            auto iscales = _mm256_srli_epi32(_mm256_and_si256(vshb, _mm256_set1_epi32(0xff)), 1);
+            s_helper.vec = _mm256_mul_ps(d, _mm256_cvtepi32_ps(_mm256_sub_epi32(iscales, _mm256_set1_epi32(64))));
+            o_helper.vec = _mm256_add_epi32(_mm256_slli_epi32(_mm256_and_si256(vshb, _mm256_set1_epi32(1)), 15), _mm256_set1_epi32(4096));
+            for (int ib = 0; ib < 4; ++ib) {
+                auto scale1 = _mm256_set1_ps(s_helper.val[ib+0]);
+                auto scale2 = _mm256_set1_ps(s_helper.val[ib+4]);
+                for (int j = 0; j < 4; ++j) {
+                    const uint32_t sh1 = shb[ib+0] >> (8 + 6*j);
+                    const uint32_t sh2 = shb[ib+4] >> (8 + 6*j);
+                    uint32_t val1 = ql[8*ib+2*j+ 0] + ((qh[8*ib+2*j+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + o_helper.val[ib+0];
+                    uint32_t val2 = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + o_helper.val[ib+4];
+                    uint32_t val3 = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + o_helper.val[ib+0];
+                    uint32_t val4 = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + o_helper.val[ib+4];
+                    auto x_val1 = _mm256_mul_ps(scale1, trellis_gen8(trellis.next8(val1, val3)));
+                    auto x_val2 = _mm256_mul_ps(scale2, trellis_gen8(trellis.next8(val2, val4)));
+                    if constexpr (nrc_y == 1) {
+                        auto y1 = _mm256_load_ps(y[0] + i*QK_K+32*ib+8*j+  0);
+                        auto y2 = _mm256_load_ps(y[0] + i*QK_K+32*ib+8*j+128);
+                        accd[0] = _mm256_fmadd_ps(y1, x_val1, accd[0]);
+                        accd[1] = _mm256_fmadd_ps(y2, x_val2, accd[1]);
+                    } else {
+                        for (int iy = 0; iy < nrc_y; ++iy) {
+                            auto y1 = _mm256_load_ps(y[iy] + i*QK_K+32*ib+8*j+  0);
+                            auto y2 = _mm256_load_ps(y[iy] + i*QK_K+32*ib+8*j+128);
+                            accd[iy] = _mm256_fmadd_ps(y1, x_val1, accd[iy]);
+                            accd[iy] = _mm256_fmadd_ps(y2, x_val2, accd[iy]);
+                        }
+                    }
                 }
             }
         }
 
-        for (int iy = 0; iy < nrc_y; ++iy) {
-            __m256 res = _mm256_mul_ps(_mm256_set1_ps(d), accd[iy]);
-            __m256 res2 = _mm256_mul_ps(_mm256_set1_ps(row_av), accd2[iy]);
-            info.store(ix, iy, hsum_float_8(res) + hsum_float_8(res2));
+        if constexpr (nrc_y == 1) {
+            info.store(ix, 0, hsum_float_8(_mm256_add_ps(accd[0], accd[1])) + dav*row_sum[0]);
+        } else {
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                info.store(ix, iy, hsum_float_8(accd[iy]) + dav*row_sum[iy]);
+            }
         }
     }
 }
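
Aside (not part of the diff): accd2 can be dropped because a dequantized iq4_kt value is the scaled trellis sample plus a constant per-row offset dav (= dptr[1]). Its dot product with a row of y therefore splits into the scaled trellis dot product plus dav * sum(y), so row_sum[iy] is computed once per activation row up front instead of being re-accumulated inside the ix loop, and the row scale d is folded into s_helper. A scalar sketch of the identity with illustrative values:

#include <cstdio>

int main() {
    const int n = 8;
    float x[n] = {0.5f, -1.25f, 2.0f, 0.0f, 3.5f, -0.75f, 1.0f, -2.0f};  // stand-in for scaled trellis values
    float y[n] = {1.0f,  2.0f, -1.0f, 4.0f, 0.5f,  1.5f, -3.0f,  2.5f};  // stand-in for one activation row
    float dav  = 0.125f;                                                 // per-row offset, as in dptr[1]

    float fused = 0.0f, dot = 0.0f, ysum = 0.0f;
    for (int i = 0; i < n; ++i) {
        fused += y[i] * (x[i] + dav);  // old path: offset folded into every product
        dot   += y[i] * x[i];          // new path: plain dot product ...
        ysum  += y[i];                 // ... plus one sum of y computed up front
    }
    printf("%f vs %f\n", fused, dot + dav * ysum);  // the two results agree (up to rounding)
    return 0;
}
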