@@ -2504,22 +2504,18 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
2504
2504
for (int rp = 0 ; rp < 4 ; rp ++ ) {
2505
2505
// Load the four block_q4_0 quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3
2506
2506
// Each 128 bit chunk (two rows) is repeated into both 128 bit lanes of its own 256 bit vector
2507
- __m256i lhs_mat_01_0 = _mm256_castsi128_si256 (_mm_loadu_si128 ((const __m128i * )(a_ptrs [rp ][b ].qs )));
2508
- lhs_mat_01_0 = _mm256_permute2f128_si256 (lhs_mat_01_0 , lhs_mat_01_0 , 0 );
2509
- __m256i lhs_mat_23_0 = _mm256_castsi128_si256 (_mm_loadu_si128 ((const __m128i * )(a_ptrs [rp ][b ].qs + 16 )));
2510
- lhs_mat_23_0 = _mm256_permute2f128_si256 (lhs_mat_23_0 , lhs_mat_23_0 , 0 );
2511
- __m256i lhs_mat_01_1 = _mm256_castsi128_si256 (_mm_loadu_si128 ((const __m128i * )(a_ptrs [rp ][b ].qs + 32 )));
2512
- lhs_mat_01_1 = _mm256_permute2f128_si256 (lhs_mat_01_1 , lhs_mat_01_1 , 0 );
2513
- __m256i lhs_mat_23_1 = _mm256_castsi128_si256 (_mm_loadu_si128 ((const __m128i * )(a_ptrs [rp ][b ].qs + 48 )));
2514
- lhs_mat_23_1 = _mm256_permute2f128_si256 (lhs_mat_23_1 , lhs_mat_23_1 , 0 );
2515
- __m256i lhs_mat_01_2 = _mm256_castsi128_si256 (_mm_loadu_si128 ((const __m128i * )(a_ptrs [rp ][b ].qs + 64 )));
2516
- lhs_mat_01_2 = _mm256_permute2f128_si256 (lhs_mat_01_2 , lhs_mat_01_2 , 0 );
2517
- __m256i lhs_mat_23_2 = _mm256_castsi128_si256 (_mm_loadu_si128 ((const __m128i * )(a_ptrs [rp ][b ].qs + 80 )));
2518
- lhs_mat_23_2 = _mm256_permute2f128_si256 (lhs_mat_23_2 , lhs_mat_23_2 , 0 );
2519
- __m256i lhs_mat_01_3 = _mm256_castsi128_si256 (_mm_loadu_si128 ((const __m128i * )(a_ptrs [rp ][b ].qs + 96 )));
2520
- lhs_mat_01_3 = _mm256_permute2f128_si256 (lhs_mat_01_3 , lhs_mat_01_3 , 0 );
2521
- __m256i lhs_mat_23_3 = _mm256_castsi128_si256 (_mm_loadu_si128 ((const __m128i * )(a_ptrs [rp ][b ].qs + 112 )));
2522
- lhs_mat_23_3 = _mm256_permute2f128_si256 (lhs_mat_23_3 , lhs_mat_23_3 , 0 );
2507
+ __m256i lhs_mat_0123_0 = _mm256_loadu_si256 ((const __m256i * )((a_ptrs [rp ][b ].qs )));
2508
+ __m256i lhs_mat_01_0 = _mm256_permute2f128_si256 (lhs_mat_0123_0 , lhs_mat_0123_0 , 0 );
2509
+ __m256i lhs_mat_23_0 = _mm256_permute2f128_si256 (lhs_mat_0123_0 , lhs_mat_0123_0 , 17 );
2510
+ __m256i lhs_mat_0123_1 = _mm256_loadu_si256 ((const __m256i * )((a_ptrs [rp ][b ].qs + 32 )));
2511
+ __m256i lhs_mat_01_1 = _mm256_permute2f128_si256 (lhs_mat_0123_1 , lhs_mat_0123_1 , 0 );
2512
+ __m256i lhs_mat_23_1 = _mm256_permute2f128_si256 (lhs_mat_0123_1 , lhs_mat_0123_1 , 17 );
2513
+ __m256i lhs_mat_0123_2 = _mm256_loadu_si256 ((const __m256i * )((a_ptrs [rp ][b ].qs + 64 )));
2514
+ __m256i lhs_mat_01_2 = _mm256_permute2f128_si256 (lhs_mat_0123_2 , lhs_mat_0123_2 , 0 );
2515
+ __m256i lhs_mat_23_2 = _mm256_permute2f128_si256 (lhs_mat_0123_2 , lhs_mat_0123_2 , 17 );
2516
+ __m256i lhs_mat_0123_3 = _mm256_loadu_si256 ((const __m256i * )((a_ptrs [rp ][b ].qs + 96 )));
2517
+ __m256i lhs_mat_01_3 = _mm256_permute2f128_si256 (lhs_mat_0123_3 , lhs_mat_0123_3 , 0 );
2518
+ __m256i lhs_mat_23_3 = _mm256_permute2f128_si256 (lhs_mat_0123_3 , lhs_mat_0123_3 , 17 );
2523
2519
2524
2520
// Shuffle pattern one - left side input
2525
2521
const __m256i lhs_mat_01_0_sp1 = _mm256_shuffle_epi32 (lhs_mat_01_0 , 160 ); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
@@ -2670,22 +2666,18 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
2670
2666
2671
2667
// Load the four block_q4_0 quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3
2672
2668
// Each 128 bit chunk (two rows) is repeated into both 128 bit lanes of its own 256 bit vector
2673
- __m256i lhs_mat_01_0 = _mm256_castsi128_si256 (_mm_loadu_si128 ((const __m128i * )(a_ptr [b ].qs )));
2674
- lhs_mat_01_0 = _mm256_permute2f128_si256 (lhs_mat_01_0 , lhs_mat_01_0 , 0 );
2675
- __m256i lhs_mat_23_0 = _mm256_castsi128_si256 (_mm_loadu_si128 ((const __m128i * )(a_ptr [b ].qs + 16 )));
2676
- lhs_mat_23_0 = _mm256_permute2f128_si256 (lhs_mat_23_0 , lhs_mat_23_0 , 0 );
2677
- __m256i lhs_mat_01_1 = _mm256_castsi128_si256 (_mm_loadu_si128 ((const __m128i * )(a_ptr [b ].qs + 32 )));
2678
- lhs_mat_01_1 = _mm256_permute2f128_si256 (lhs_mat_01_1 , lhs_mat_01_1 , 0 );
2679
- __m256i lhs_mat_23_1 = _mm256_castsi128_si256 (_mm_loadu_si128 ((const __m128i * )(a_ptr [b ].qs + 48 )));
2680
- lhs_mat_23_1 = _mm256_permute2f128_si256 (lhs_mat_23_1 , lhs_mat_23_1 , 0 );
2681
- __m256i lhs_mat_01_2 = _mm256_castsi128_si256 (_mm_loadu_si128 ((const __m128i * )(a_ptr [b ].qs + 64 )));
2682
- lhs_mat_01_2 = _mm256_permute2f128_si256 (lhs_mat_01_2 , lhs_mat_01_2 , 0 );
2683
- __m256i lhs_mat_23_2 = _mm256_castsi128_si256 (_mm_loadu_si128 ((const __m128i * )(a_ptr [b ].qs + 80 )));
2684
- lhs_mat_23_2 = _mm256_permute2f128_si256 (lhs_mat_23_2 , lhs_mat_23_2 , 0 );
2685
- __m256i lhs_mat_01_3 = _mm256_castsi128_si256 (_mm_loadu_si128 ((const __m128i * )(a_ptr [b ].qs + 96 )));
2686
- lhs_mat_01_3 = _mm256_permute2f128_si256 (lhs_mat_01_3 , lhs_mat_01_3 , 0 );
2687
- __m256i lhs_mat_23_3 = _mm256_castsi128_si256 (_mm_loadu_si128 ((const __m128i * )(a_ptr [b ].qs + 112 )));
2688
- lhs_mat_23_3 = _mm256_permute2f128_si256 (lhs_mat_23_3 , lhs_mat_23_3 , 0 );
2669
+ __m256i lhs_mat_0123_0 = _mm256_loadu_si256 ((const __m256i * )((a_ptr [b ].qs )));
2670
+ __m256i lhs_mat_01_0 = _mm256_permute2f128_si256 (lhs_mat_0123_0 , lhs_mat_0123_0 , 0 );
2671
+ __m256i lhs_mat_23_0 = _mm256_permute2f128_si256 (lhs_mat_0123_0 , lhs_mat_0123_0 , 17 );
2672
+ __m256i lhs_mat_0123_1 = _mm256_loadu_si256 ((const __m256i * )((a_ptr [b ].qs + 32 )));
2673
+ __m256i lhs_mat_01_1 = _mm256_permute2f128_si256 (lhs_mat_0123_1 , lhs_mat_0123_1 , 0 );
2674
+ __m256i lhs_mat_23_1 = _mm256_permute2f128_si256 (lhs_mat_0123_1 , lhs_mat_0123_1 , 17 );
2675
+ __m256i lhs_mat_0123_2 = _mm256_loadu_si256 ((const __m256i * )((a_ptr [b ].qs + 64 )));
2676
+ __m256i lhs_mat_01_2 = _mm256_permute2f128_si256 (lhs_mat_0123_2 , lhs_mat_0123_2 , 0 );
2677
+ __m256i lhs_mat_23_2 = _mm256_permute2f128_si256 (lhs_mat_0123_2 , lhs_mat_0123_2 , 17 );
2678
+ __m256i lhs_mat_0123_3 = _mm256_loadu_si256 ((const __m256i * )((a_ptr [b ].qs + 96 )));
2679
+ __m256i lhs_mat_01_3 = _mm256_permute2f128_si256 (lhs_mat_0123_3 , lhs_mat_0123_3 , 0 );
2680
+ __m256i lhs_mat_23_3 = _mm256_permute2f128_si256 (lhs_mat_0123_3 , lhs_mat_0123_3 , 17 );
2689
2681
2690
2682
// Shuffle pattern one - left side input
2691
2683
0 commit comments