Skip to content

Commit c950fc3

Browse files
Srihari-mcw
authored and committed
Make updates to reduce number of load instructions
1 parent 364dc96 commit c950fc3

File tree

1 file changed

+24
-32
lines changed

1 file changed

+24
-32
lines changed

ggml/src/ggml-aarch64.c

Lines changed: 24 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -2504,22 +2504,18 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
25042504
for (int rp = 0; rp < 4; rp++) {
25052505
// Load the four block_q4_0 quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3
25062506
// Loaded as 256 bit vectors; each 128 bit half (A0,A1 / A2,A3) is then repeated into both halves of its own 256 bit vector
2507-
__m256i lhs_mat_01_0 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptrs[rp][b].qs)));
2508-
lhs_mat_01_0 = _mm256_permute2f128_si256(lhs_mat_01_0, lhs_mat_01_0, 0);
2509-
__m256i lhs_mat_23_0 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptrs[rp][b].qs + 16)));
2510-
lhs_mat_23_0 = _mm256_permute2f128_si256(lhs_mat_23_0, lhs_mat_23_0, 0);
2511-
__m256i lhs_mat_01_1 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptrs[rp][b].qs + 32)));
2512-
lhs_mat_01_1 = _mm256_permute2f128_si256(lhs_mat_01_1, lhs_mat_01_1, 0);
2513-
__m256i lhs_mat_23_1 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptrs[rp][b].qs + 48)));
2514-
lhs_mat_23_1 = _mm256_permute2f128_si256(lhs_mat_23_1, lhs_mat_23_1, 0);
2515-
__m256i lhs_mat_01_2 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptrs[rp][b].qs + 64)));
2516-
lhs_mat_01_2 = _mm256_permute2f128_si256(lhs_mat_01_2, lhs_mat_01_2, 0);
2517-
__m256i lhs_mat_23_2 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptrs[rp][b].qs + 80)));
2518-
lhs_mat_23_2 = _mm256_permute2f128_si256(lhs_mat_23_2, lhs_mat_23_2, 0);
2519-
__m256i lhs_mat_01_3 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptrs[rp][b].qs + 96)));
2520-
lhs_mat_01_3 = _mm256_permute2f128_si256(lhs_mat_01_3, lhs_mat_01_3, 0);
2521-
__m256i lhs_mat_23_3 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptrs[rp][b].qs + 112)));
2522-
lhs_mat_23_3 = _mm256_permute2f128_si256(lhs_mat_23_3, lhs_mat_23_3, 0);
2507+
__m256i lhs_mat_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs)));
2508+
__m256i lhs_mat_01_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 0);
2509+
__m256i lhs_mat_23_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 17);
2510+
__m256i lhs_mat_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 32)));
2511+
__m256i lhs_mat_01_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 0);
2512+
__m256i lhs_mat_23_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 17);
2513+
__m256i lhs_mat_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 64)));
2514+
__m256i lhs_mat_01_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 0);
2515+
__m256i lhs_mat_23_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 17);
2516+
__m256i lhs_mat_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 96)));
2517+
__m256i lhs_mat_01_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 0);
2518+
__m256i lhs_mat_23_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 17);
25232519

25242520
// Shuffle pattern one - left side input
25252521
const __m256i lhs_mat_01_0_sp1 = _mm256_shuffle_epi32(lhs_mat_01_0, 160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
@@ -2670,22 +2666,18 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
26702666

26712667
// Load the four block_q4_0 quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3
26722668
// Loaded as 256 bit vectors; each 128 bit half (A0,A1 / A2,A3) is then repeated into both halves of its own 256 bit vector
2673-
__m256i lhs_mat_01_0 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs)));
2674-
lhs_mat_01_0 = _mm256_permute2f128_si256(lhs_mat_01_0, lhs_mat_01_0, 0);
2675-
__m256i lhs_mat_23_0 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 16)));
2676-
lhs_mat_23_0 = _mm256_permute2f128_si256(lhs_mat_23_0, lhs_mat_23_0, 0);
2677-
__m256i lhs_mat_01_1 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 32)));
2678-
lhs_mat_01_1 = _mm256_permute2f128_si256(lhs_mat_01_1, lhs_mat_01_1, 0);
2679-
__m256i lhs_mat_23_1 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 48)));
2680-
lhs_mat_23_1 = _mm256_permute2f128_si256(lhs_mat_23_1, lhs_mat_23_1, 0);
2681-
__m256i lhs_mat_01_2 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 64)));
2682-
lhs_mat_01_2 = _mm256_permute2f128_si256(lhs_mat_01_2, lhs_mat_01_2, 0);
2683-
__m256i lhs_mat_23_2 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 80)));
2684-
lhs_mat_23_2 = _mm256_permute2f128_si256(lhs_mat_23_2, lhs_mat_23_2, 0);
2685-
__m256i lhs_mat_01_3 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 96)));
2686-
lhs_mat_01_3 = _mm256_permute2f128_si256(lhs_mat_01_3, lhs_mat_01_3, 0);
2687-
__m256i lhs_mat_23_3 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 112)));
2688-
lhs_mat_23_3 = _mm256_permute2f128_si256(lhs_mat_23_3, lhs_mat_23_3, 0);
2669+
__m256i lhs_mat_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs)));
2670+
__m256i lhs_mat_01_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 0);
2671+
__m256i lhs_mat_23_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 17);
2672+
__m256i lhs_mat_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 32)));
2673+
__m256i lhs_mat_01_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 0);
2674+
__m256i lhs_mat_23_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 17);
2675+
__m256i lhs_mat_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 64)));
2676+
__m256i lhs_mat_01_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 0);
2677+
__m256i lhs_mat_23_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 17);
2678+
__m256i lhs_mat_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 96)));
2679+
__m256i lhs_mat_01_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 0);
2680+
__m256i lhs_mat_23_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 17);
26892681

26902682
// Shuffle pattern one - left side input
26912683

0 commit comments

Comments
 (0)