@@ -509,14 +509,25 @@ static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
509
509
const __m256i ax = _mm256_sign_epi8 (x , x );
510
510
// Sign the values of the y vectors
511
511
const __m256i sy = _mm256_sign_epi8 (y , x );
512
+ #if __AVXVNNI__
513
+ const __m256i zero = _mm256_setzero_si256 ();
514
+ const __m256i summed_pairs = _mm256_dpbusd_epi32 (zero , ax , sy );
515
+ return _mm256_cvtepi32_ps (summed_pairs );
516
+ #else
512
517
// Perform multiplication and create 16-bit values
513
518
const __m256i dot = _mm256_maddubs_epi16 (ax , sy );
514
519
return sum_i16_pairs_float (dot );
520
+ #endif
515
521
}
516
522
517
523
// Packs a 256-bit vector of 16-bit lanes into a 128-bit vector holding one
// byte per lane. Per the bit diagrams below, each lane is expected to look
// like 0000_abcd_0000_efgh (one nibble in the low half of each byte) and is
// reduced to the single byte abcd_efgh.
// NOTE(review): assumes the upper nibble of every input byte is zero, as the
// diagrams imply -- confirm against callers.
static inline __m128i packNibbles ( __m256i bytes )
518
524
{
519
525
// Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
526
// NOTE(review): _mm256_cvtepi16_epi8 is documented by Intel as requiring
// AVX512BW and AVX512VL; __AVX512F__ alone does not guarantee either.
// Confirm whether this guard should test __AVX512BW__ and __AVX512VL__.
+ #if __AVX512F__
527
// Shift each 16-bit lane right by 4 so the high byte's nibble lands in the
// upper half of the low byte.
+ const __m256i bytes_srli_4 = _mm256_srli_epi16 (bytes , 4 ); // 0000_0000_abcd_0000
528
// Merge with the unshifted lane; the low byte now holds both nibbles.
+ bytes = _mm256_or_si256 (bytes , bytes_srli_4 ); // 0000_abcd_abcd_efgh
529
// Truncate every 16-bit lane to its low 8 bits, giving 16 packed bytes.
+ return _mm256_cvtepi16_epi8 (bytes ); // abcd_efgh
530
+ #else
520
531
// Mask with 0x00FF in every 16-bit lane to separate the two bytes of a lane.
const __m256i lowByte = _mm256_set1_epi16 ( 0xFF );
521
532
// high keeps only the upper byte of each lane (bits outside the mask).
__m256i high = _mm256_andnot_si256 ( lowByte , bytes );
522
533
// low keeps only the lower byte of each lane.
__m256i low = _mm256_and_si256 ( lowByte , bytes );
@@ -527,6 +538,7 @@ static inline __m128i packNibbles( __m256i bytes )
527
538
// Split the 256-bit value into its lower (r0) and upper (r1) 128-bit halves.
__m128i r0 = _mm256_castsi256_si128 ( bytes );
528
539
__m128i r1 = _mm256_extracti128_si256 ( bytes , 1 );
529
540
// Narrow both halves from 16-bit lanes to bytes with unsigned saturation;
// r0 fills the low 8 bytes of the result and r1 the high 8 bytes.
return _mm_packus_epi16 ( r0 , r1 );
541
+ #endif
530
542
}
531
543
#else
532
544
static inline __m128i packNibbles ( __m128i bytes1 , __m128i bytes2 )
0 commit comments