@@ -3409,10 +3409,9 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
3409
3409
#if defined(__ARM_FEATURE_MATMUL_INT8)
3410
3410
if (nrc == 2) {
3411
3411
const block_q4_0 * restrict vx0 = vx;
3412
- const block_q4_0 * restrict vx1 = vx + bx;
3413
-
3412
+ const block_q4_0 * restrict vx1 = (const block_q4_0 *) ((const uint8_t*)vx + bx);
3414
3413
const block_q8_0 * restrict vy0 = vy;
3415
- const block_q8_0 * restrict vy1 = vy + by;
3414
+ const block_q8_0 * restrict vy1 = (const block_q8_0 *) ((const uint8_t*) vy + by) ;
3416
3415
3417
3416
float32x4_t sumv0 = vdupq_n_f32(0.0f);
3418
3417
@@ -3446,10 +3445,12 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
3446
3445
const int8x16_t y1_l = vld1q_s8(b_y1->qs);
3447
3446
const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
3448
3447
3449
- float32x4_t scale = {GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
3450
- GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
3451
- GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
3452
- GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};
3448
+ float32_t _scale[4] = { GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
3449
+ GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
3450
+ GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
3451
+ GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};
3452
+
3453
+ float32x4_t scale = vld1q_f32(_scale);
3453
3454
3454
3455
int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
3455
3456
int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
@@ -3776,9 +3777,9 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
3776
3777
#if defined(__ARM_FEATURE_MATMUL_INT8)
3777
3778
if (nrc == 2) {
3778
3779
const block_q4_1 * restrict vx0 = vx;
3779
- const block_q4_1 * restrict vx1 = vx + bx;
3780
+ const block_q4_1 * restrict vx1 = (const block_q4_1 *) ((const uint8_t*) vx + bx) ;
3780
3781
const block_q8_1 * restrict vy0 = vy;
3781
- const block_q8_1 * restrict vy1 = vy + by;
3782
+ const block_q8_1 * restrict vy1 = (const block_q8_1 *) ((const uint8_t*) vy + by) ;
3782
3783
3783
3784
float32x4_t sumv0 = vdupq_n_f32(0.0f);
3784
3785
float32x4_t summs0 = vdupq_n_f32(0.0f);
@@ -3789,11 +3790,11 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
3789
3790
const block_q8_1 * restrict b_y0 = &vy0[i];
3790
3791
const block_q8_1 * restrict b_y1 = &vy1[i];
3791
3792
3792
- float32x4_t summs_t = {GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y0->s),
3793
- GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y0->s),
3794
- GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y1->s),
3795
- GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y1->s)};
3796
- summs0 += summs_t;
3793
+ float32_t summs_t[4] = {GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y0->s),
3794
+ GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y0->s),
3795
+ GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y1->s),
3796
+ GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y1->s)};
3797
+ summs0 = vaddq_f32(summs0, vld1q_f32( summs_t)) ;
3797
3798
3798
3799
const uint8x16_t m4b = vdupq_n_u8(0x0F);
3799
3800
@@ -3813,10 +3814,11 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
3813
3814
const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
3814
3815
3815
3816
// mmla into int32x4_t
3816
- float32x4_t scale = {GGML_FP16_TO_FP32(b_x0->d)*b_y0->d,
3817
- GGML_FP16_TO_FP32(b_x0->d)*b_y1->d,
3818
- GGML_FP16_TO_FP32(b_x1->d)*b_y0->d,
3819
- GGML_FP16_TO_FP32(b_x1->d)*b_y1->d};
3817
+ float32_t _scale[4] = {GGML_FP16_TO_FP32(b_x0->d)*b_y0->d,
3818
+ GGML_FP16_TO_FP32(b_x0->d)*b_y1->d,
3819
+ GGML_FP16_TO_FP32(b_x1->d)*b_y0->d,
3820
+ GGML_FP16_TO_FP32(b_x1->d)*b_y1->d};
3821
+ float32x4_t scale = vld1q_f32(_scale);
3820
3822
3821
3823
int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
3822
3824
int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
@@ -3835,7 +3837,7 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
3835
3837
3836
3838
float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2);
3837
3839
float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
3838
- sumv2 = sumv2 + summs0;
3840
+ sumv2 = vaddq_f32( sumv2, summs0) ;
3839
3841
3840
3842
vst1_f32(s, vget_low_f32(sumv2));
3841
3843
vst1_f32(s + bs, vget_high_f32(sumv2));
@@ -4588,35 +4590,36 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
4588
4590
4589
4591
#if defined(__ARM_FEATURE_MATMUL_INT8)
4590
4592
if (nrc == 2) {
4591
- const block_q8_0 * restrict vx0 = vx;
4592
- const block_q8_0 * restrict vx1 = vx + bx;
4593
+ const block_q4_0 * restrict vx0 = vx;
4594
+ const block_q4_0 * restrict vx1 = (const block_q4_0 *) ((const uint8_t*) vx + bx) ;
4593
4595
const block_q8_0 * restrict vy0 = vy;
4594
- const block_q8_0 * restrict vy1 = vy + by;
4596
+ const block_q8_0 * restrict vy1 = (const block_q8_0 *) ((const uint8_t*) vy + by) ;
4595
4597
4596
4598
float32x4_t sumv0 = vdupq_n_f32(0.0f);
4597
4599
4598
4600
for (int i = 0; i < nb; i++) {
4599
- const block_q8_0 * restrict b_x0 = &vx0[i];
4601
+ const block_q4_0 * restrict b_x0 = &vx0[i];
4600
4602
const block_q8_0 * restrict b_y0 = &vy0[i];
4601
4603
4602
- const block_q8_0 * restrict b_x1 = &vx1[i];
4604
+ const block_q4_0 * restrict b_x1 = &vx1[i];
4603
4605
const block_q8_0 * restrict b_y1 = &vy1[i];
4604
4606
4605
- const int8x16_t x0_l = vld1q_s8(b_x0->qs);
4606
- const int8x16_t x0_h = vld1q_s8(b_x0->qs + 16);
4607
- const int8x16_t x1_l = vld1q_s8(b_x1->qs);
4608
- const int8x16_t x1_h = vld1q_s8(b_x1->qs + 16);
4607
+ const int8x16_t x0_l = vld1q_s8((const int8_t*) b_x0->qs);
4608
+ const int8x16_t x0_h = vld1q_s8((const int8_t*) b_x0->qs + 16);
4609
+ const int8x16_t x1_l = vld1q_s8((const int8_t*) b_x1->qs);
4610
+ const int8x16_t x1_h = vld1q_s8((const int8_t*) b_x1->qs + 16);
4609
4611
4610
4612
// load y
4611
4613
const int8x16_t y0_l = vld1q_s8(b_y0->qs);
4612
4614
const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16);
4613
4615
const int8x16_t y1_l = vld1q_s8(b_y1->qs);
4614
4616
const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
4615
4617
4616
- float32x4_t scale = {GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
4617
- GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
4618
- GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
4619
- GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};
4618
+ float32_t _scale[4] = {GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
4619
+ GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
4620
+ GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
4621
+ GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};
4622
+ float32x4_t scale = vld1q_f32(_scale);
4620
4623
4621
4624
int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
4622
4625
int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
0 commit comments