@@ -5573,7 +5573,88 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
     uint32_t utmp[4];
 
-#ifdef __ARM_NEON
+#ifdef __ARM_FEATURE_SVE
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8));
+
+        memcpy(utmp, x[i].scales, K_SCALE_SIZE);
+
+        uint32x2_t mins8 = { 0 };
+        mins8 = vset_lane_u32(utmp[1] & kmask1, mins8, 0);
+        mins8 = vset_lane_u32(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), mins8, 1);
+
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[0] &= kmask1;
+
+        const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8)));
+        const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16(q8sums), vget_low_s16(mins)),
+                                         vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)));
+        sumf -= dmin * vaddvq_s32(prod);
+
+        const uint8_t * scales = (const uint8_t *)utmp;
+
+        const uint8_t * restrict q4 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const int vector_length = ggml_cpu_get_sve_cnt()*8;
+        const svuint8_t m4b = svdup_n_u8(0xf);
+        const svint32_t mzero = svdup_n_s32(0);
+        svint32_t sumi1 = svdup_n_s32(0);
+        svint32_t sumi1_1 = svdup_n_s32(0);
+        svint32_t sumi1_2 = svdup_n_s32(0);
+        svint32_t sumi2 = svdup_n_s32(0);
+        svint32_t sumi2_1 = svdup_n_s32(0);
+        svint32_t sumi2_2 = svdup_n_s32(0);
+        switch (vector_length) {
+            case 128:
+                {
+                    for (int j = 0; j < QK_K/64; ++j) {
+                        svint8_t q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4), m4b));
+                        svint8_t q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
+                        sumi1_1 = svmla_n_s32_x(svptrue_b32(), sumi1_1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
+                        q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4 + 16), m4b));
+                        q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
+                        sumi1_2 = svmla_n_s32_x(svptrue_b32(), sumi1_2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
+
+                        q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4), 4));
+                        q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
+                        sumi2_1 = svmla_n_s32_x(svptrue_b32(), sumi2_1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
+                        q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4 + 16), 4));
+                        q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
+                        sumi2_2 = svmla_n_s32_x(svptrue_b32(), sumi2_2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
+                        q4 += 32;
+                    }
+                    sumi1 = svadd_s32_x(svptrue_b32(), sumi1_1, sumi1_2);
+                    sumi2 = svadd_s32_x(svptrue_b32(), sumi2_1, sumi2_2);
+                    sumf += d * (svaddv_s32(svptrue_b32(), svadd_s32_x(svptrue_b32(), sumi1, sumi2)));
+                } break;
+            case 256:
+            case 512:
+                {
+                    for (int j = 0; j < QK_K/64; ++j) {
+                        const svuint8_t q4bits = svld1_u8(svptrue_pat_b8(SV_VL32), q4); q4 += 32;
+                        svint8_t q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_pat_b8(SV_VL32), q4bits, m4b));
+                        svint8_t q8bytes = svld1_s8(svptrue_pat_b8(SV_VL32), q8); q8 += 32;
+                        sumi1 = svmla_n_s32_x(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
+
+                        q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q4bits, 4));
+                        q8bytes = svld1_s8(svptrue_pat_b8(SV_VL32), q8); q8 += 32;
+                        sumi2 = svmla_n_s32_x(svptrue_pat_b32(SV_VL8), sumi2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
+                    }
+                    sumf += d * (svaddv_s32(svptrue_pat_b32(SV_VL8), svadd_s32_x(svptrue_pat_b32(SV_VL8), sumi1, sumi2)));
+                } break;
+            default:
+                assert(false && "Unsupported vector length");
+                break;
+        }
+    }
+    *s = sumf;
+#elif __ARM_NEON
 
     const uint8x16_t m4b = vdupq_n_u8(0xf);
     const int32x4_t mzero = vdupq_n_s32(0);
 
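Note for reviewers comparing this with the NEON path: per 256-weight superblock the kernel computes d * sum_j scales[j] * (q4_j . q8_j) minus dmin * sum_k mins[k] * (paired bsums)[k], where the eight 6-bit scales and eight 6-bit mins come out of the kmask1/kmask2/kmask3 shuffle of x[i].scales. A minimal scalar sketch of the same arithmetic, assuming QK_K == 256 and already-unpacked scales/mins (q4k_block_ref is a hypothetical reference helper, not part of this patch):

    // Hypothetical scalar reference for one Q4_K x Q8_K superblock (QK_K == 256).
    // 'scales' and 'mins' hold the eight unpacked 6-bit values each, as the
    // kmask shuffle above produces them; illustrative only, not the patch code.
    static float q4k_block_ref(float d, float dmin,
                               const uint8_t scales[8], const uint8_t mins[8],
                               const int16_t bsums[16],
                               const uint8_t * q4, const int8_t * q8) {
        int32_t summins = 0;
        for (int k = 0; k < 8; ++k) {
            // mirrors q8sums: adjacent pairs of the 16 per-16-weight bsums
            summins += (int32_t) mins[k] * (bsums[2*k] + bsums[2*k + 1]);
        }
        int32_t sumi = 0;
        for (int j = 0; j < 256/64; ++j) {         // 4 chunks of 64 weights
            int32_t s1 = 0, s2 = 0;
            for (int l = 0; l < 32; ++l) {
                s1 += (q4[l] & 0xf) * q8[l];       // low nibbles vs first 32 q8
                s2 += (q4[l] >>  4) * q8[l + 32];  // high nibbles vs next 32 q8
            }
            sumi += scales[2*j + 0] * s1 + scales[2*j + 1] * s2;
            q4 += 32; q8 += 64;
        }
        return d * (float) sumi - dmin * (float) summins;
    }

The 128-bit SVE case runs each 32-byte chunk as two 16-byte halves (hence the sumi1_1/sumi1_2 and sumi2_1/sumi2_2 accumulators), while the 256- and 512-bit cases pin the predicate to 32 active bytes (SV_VL32) so a single svdot_s32 covers a whole chunk.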