@@ -11371,40 +11371,68 @@ void ggml_vec_dot_q1_3_q8_0(int n, float * restrict s, size_t bs, const void * r
     __m256 accumf = _mm256_setzero_ps();
 
     for (int i = 0; i < nb; ++i) {
-        {
-            __m256i x0 = _mm256_set_epi32(q1_3_grid[x[i].q[7]], q1_3_grid[x[i].q[6]],
-                                          q1_3_grid[x[i].q[5]], q1_3_grid[x[i].q[4]],
-                                          q1_3_grid[x[i].q[3]], q1_3_grid[x[i].q[2]],
-                                          q1_3_grid[x[i].q[1]], q1_3_grid[x[i].q[0]]);
-            __m256i y0 = _mm256_lddqu_si256((const __m256i_u *) (y[2*i].qs));
-
-            __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(y[2*i].d));
-
-            __m256 q = mul_sum_i8_pairs_float(x0, y0);
-
-            accumf = _mm256_fmadd_ps(d, q, accumf);
-        }
+        // __m128i x12a = _mm_maskload_epi32((const int32_t *) x[i].q, _mm_set_epi32(0, -1, -1, -1));
+        // __m128i x12b = _mm_insert_epi8(x12a, x[i].qs[0], 12);
+        // WARNING: reading 3 bytes further than necessary. It's faster than the above on my CPU, though.
+        __m128i x12b = _mm_loadu_si128((const __m128i_u *) x[i].q);
+        __m256i x12 = MM256_SET_M128I(x12b, x12b);
 
         {
-            __m256i x1 = _mm256_castsi128_si256(_mm_set_epi32(q1_3_grid[x[i].q[11]], q1_3_grid[x[i].q[10]],
-                                                              q1_3_grid[x[i].q[9]],  q1_3_grid[x[i].q[8]]));
-            __m256i x2 = _mm256_cvtepu8_epi16(_mm_maskload_epi32((const int32_t *) x[i].q, _mm_set_epi32(0, -1, -1, -1)));
+            __m256i x0l = _mm256_shuffle_epi8(x12, _mm256_set_epi8(5, -1, 5, -1, 5, -1, 5, -1,
+                                                                   4, -1, 4, -1, 4, -1, 4, -1,
+                                                                   1, -1, 1, -1, 1, -1, 1, -1,
+                                                                   0, -1, 0, -1, 0, -1, 0, -1));
+            __m256i x0h = _mm256_shuffle_epi8(x12, _mm256_set_epi8(7, -1, 7, -1, 7, -1, 7, -1,
+                                                                   6, -1, 6, -1, 6, -1, 6, -1,
+                                                                   3, -1, 3, -1, 3, -1, 3, -1,
+                                                                   2, -1, 2, -1, 2, -1, 2, -1));
+            __m256i x1l = _mm256_shuffle_epi8(x12, _mm256_set_epi8(7, -1, 6, -1, 5, -1, 4, -1,
+                                                                   3, -1, 2, -1, 1, -1, 0, -1,
+                                                                   9, -1, 9, -1, 9, -1, 9, -1,
+                                                                   8, -1, 8, -1, 8, -1, 8, -1));
+            __m256i x1h = _mm256_shuffle_epi8(x12, _mm256_set_epi8(12, -1, 12, -1, 12, -1, 12, -1,
+                                                                   11, -1, 10, -1,  9, -1,  8, -1,
+                                                                   11, -1, 11, -1, 11, -1, 11, -1,
+                                                                   10, -1, 10, -1, 10, -1, 10, -1));
+            const __m256i shift0 = _mm256_set_epi16(3, 9, 27, 81,
+                                                    3, 9, 27, 81,
+                                                    3, 9, 27, 81,
+                                                    3, 9, 27, 81);
+            const __m256i shift1l = _mm256_set_epi16(1, 1, 1, 1,
+                                                     1, 1, 1, 1,
+                                                     3, 9, 27, 81,
+                                                     3, 9, 27, 81);
+            const __m256i shift1h = _mm256_set_epi16(3, 9, 27, 81,
+                                                     1, 1, 1, 1,
+                                                     3, 9, 27, 81,
+                                                     3, 9, 27, 81);
+            x0l = _mm256_mullo_epi16(x0l, shift0);
+            x0h = _mm256_mullo_epi16(x0h, shift0);
+            x1l = _mm256_mullo_epi16(x1l, shift1l);
+            x1h = _mm256_mullo_epi16(x1h, shift1h);
+            x0l = _mm256_mulhi_epu16(x0l, _mm256_set1_epi16(3));
+            x0h = _mm256_mulhi_epu16(x0h, _mm256_set1_epi16(3));
+            x1l = _mm256_mulhi_epu16(x1l, _mm256_set1_epi16(3));
+            x1h = _mm256_mulhi_epu16(x1h, _mm256_set1_epi16(3));
+            x0l = _mm256_sub_epi16(x0l, _mm256_set1_epi16(1));
+            x0h = _mm256_sub_epi16(x0h, _mm256_set1_epi16(1));
+            x1l = _mm256_sub_epi16(x1l, _mm256_set1_epi16(1));
+            x1h = _mm256_sub_epi16(x1h, _mm256_set1_epi16(1));
+
+            __m256i x0 = _mm256_packs_epi16(x0l, x0h);
+            __m256i x1 = _mm256_packs_epi16(x1l, x1h);
+
+            __m256i y0 = _mm256_lddqu_si256((const __m256i_u *) (y[2*i + 0].qs));
             __m256i y1 = _mm256_lddqu_si256((const __m256i_u *) (y[2*i + 1].qs));
 
-            x2 = _mm256_mulhi_epu16(x2, _mm256_set1_epi16(3 << 8));
-            x2 = _mm256_sub_epi16(x2, _mm256_set1_epi16(1));
-
-            // TODO: reduce shuffling
-            x2 = _mm256_packs_epi16(x2, _mm256_setzero_si256());
-            x2 = _mm256_permute4x64_epi64(x2, _MM_SHUFFLE(3, 1, 2, 0));
-            __m128i x2_l = _mm_insert_epi32(_mm256_castsi256_si128(x2), q1_3_grid[x[i].qs[0]], 3);
-            x1 = _mm256_inserti128_si256(x1, x2_l, 1);
-
-            __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(y[2*i + 1].d));
+            __m256 d0 = _mm256_set1_ps(GGML_FP16_TO_FP32(y[2*i].d));
+            __m256 d1 = _mm256_set1_ps(GGML_FP16_TO_FP32(y[2*i + 1].d));
 
-            __m256 q = mul_sum_i8_pairs_float(x1, y1);
+            __m256 q0 = mul_sum_i8_pairs_float(x0, y0);
+            __m256 q1 = mul_sum_i8_pairs_float(x1, y1);
 
-            accumf = _mm256_fmadd_ps(d, q, accumf);
+            accumf = _mm256_fmadd_ps(d0, q0, accumf);
+            accumf = _mm256_fmadd_ps(d1, q1, accumf);
         }
     }
 
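The new code above replaces the `q1_3_grid` table lookups with an in-register fixed-point base-3 decode: each of the first 12 bytes of a q1_3 block packs 5 ternary digits (the 13th byte, `x[i].qs[0]`, packs 4), and digit k is recovered by multiplying the byte by 3^k, keeping the low 8 bits, then multiplying by 3 and taking the high byte. Below is a minimal scalar sketch of that mechanism, assuming the bytes are packed as b ≈ Σ_k d_k · 256 / 3^(k+1); the `extract_trit` helper is hypothetical, for illustration only, not part of ggml:

```c
#include <stdint.h>

// Scalar model of the SIMD trit extraction above (illustration only).
// In the SIMD code each packed byte sits in the high half of a 16-bit
// lane, so the "& 0xFF" here plays the role of the implicit mod 2^16
// of _mm256_mullo_epi16 with the shift0/shift1l/shift1h multipliers.
static inline int extract_trit(uint8_t b, int k) {
    static const uint16_t pow3[5] = {1, 3, 9, 27, 81};
    uint16_t v = (uint16_t)((b * pow3[k]) & 0xFF); // bring digit k to the top
    return (int)((v * 3) >> 8) - 1;                // high byte of v*3 is the digit in {0,1,2}
                                                   // (_mm256_mulhi_epu16 by 3); subtracting 1
                                                   // maps it to {-1,0,1} (_mm256_sub_epi16 by 1)
}
```

As a sanity check under that packing assumption, b = 0xFF decodes to +1 at every position and b = 0x00 to -1; the SIMD version computes the same 16-bit results and narrows them to int8 with `_mm256_packs_epi16` before the dot product.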