@@ -3415,10 +3415,9 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
 #if defined(__ARM_FEATURE_MATMUL_INT8)
     if (nrc == 2) {
         const block_q4_0 * restrict vx0 = vx;
-        const block_q4_0 * restrict vx1 = vx + bx;
-
+        const block_q4_0 * restrict vx1 = (const block_q4_0 *) ((const uint8_t*)vx + bx);
         const block_q8_0 * restrict vy0 = vy;
-        const block_q8_0 * restrict vy1 = vy + by;
+        const block_q8_0 * restrict vy1 = (const block_q8_0 *) ((const uint8_t*) vy + by) ;

         float32x4_t sumv0 = vdupq_n_f32(0.0f);

@@ -3452,10 +3451,12 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
             const int8x16_t y1_l = vld1q_s8(b_y1->qs);
             const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);

-            float32x4_t scale = {GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
-                                 GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
-                                 GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
-                                 GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};
+            float32_t _scale[4] = { GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
+                                    GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
+                                    GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
+                                    GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};
+
+            float32x4_t scale = vld1q_f32(_scale);

             int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
             int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
@@ -3782,9 +3783,9 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
 #if defined(__ARM_FEATURE_MATMUL_INT8)
     if (nrc == 2) {
         const block_q4_1 * restrict vx0 = vx;
-        const block_q4_1 * restrict vx1 = vx + bx;
+        const block_q4_1 * restrict vx1 = (const block_q4_1 *) ((const uint8_t*) vx + bx) ;
         const block_q8_1 * restrict vy0 = vy;
-        const block_q8_1 * restrict vy1 = vy + by;
+        const block_q8_1 * restrict vy1 = (const block_q8_1 *) ((const uint8_t*) vy + by) ;

         float32x4_t sumv0 = vdupq_n_f32(0.0f);
         float32x4_t summs0 = vdupq_n_f32(0.0f);
@@ -3795,11 +3796,11 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
             const block_q8_1 * restrict b_y0 = &vy0[i];
             const block_q8_1 * restrict b_y1 = &vy1[i];

-            float32x4_t summs_t = {GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y0->s),
-                                   GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y0->s),
-                                   GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y1->s),
-                                   GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y1->s)};
-            summs0 += summs_t;
+            float32_t summs_t[4] = {GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y0->s),
+                                    GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y0->s),
+                                    GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y1->s),
+                                    GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y1->s)};
+            summs0 = vaddq_f32(summs0, vld1q_f32( summs_t)) ;

             const uint8x16_t m4b = vdupq_n_u8(0x0F);

@@ -3819,10 +3820,11 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
             const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);

             // mmla into int32x4_t
-            float32x4_t scale = {GGML_FP16_TO_FP32(b_x0->d)*b_y0->d,
-                                 GGML_FP16_TO_FP32(b_x0->d)*b_y1->d,
-                                 GGML_FP16_TO_FP32(b_x1->d)*b_y0->d,
-                                 GGML_FP16_TO_FP32(b_x1->d)*b_y1->d};
+            float32_t _scale[4] = {GGML_FP16_TO_FP32(b_x0->d)*b_y0->d,
+                                   GGML_FP16_TO_FP32(b_x0->d)*b_y1->d,
+                                   GGML_FP16_TO_FP32(b_x1->d)*b_y0->d,
+                                   GGML_FP16_TO_FP32(b_x1->d)*b_y1->d};
+            float32x4_t scale = vld1q_f32(_scale);

             int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
             int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
@@ -3841,7 +3843,7 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r

         float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2);
         float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
-        sumv2 = sumv2 + summs0;
+        sumv2 = vaddq_f32( sumv2, summs0) ;

         vst1_f32(s, vget_low_f32(sumv2));
         vst1_f32(s + bs, vget_high_f32(sumv2));
@@ -4594,35 +4596,36 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r

 #if defined(__ARM_FEATURE_MATMUL_INT8)
     if (nrc == 2) {
-        const block_q8_0 * restrict vx0 = vx;
-        const block_q8_0 * restrict vx1 = vx + bx;
+        const block_q4_0 * restrict vx0 = vx;
+        const block_q4_0 * restrict vx1 = (const block_q4_0 *) ((const uint8_t*) vx + bx) ;
         const block_q8_0 * restrict vy0 = vy;
-        const block_q8_0 * restrict vy1 = vy + by;
+        const block_q8_0 * restrict vy1 = (const block_q8_0 *) ((const uint8_t*) vy + by) ;

         float32x4_t sumv0 = vdupq_n_f32(0.0f);

         for (int i = 0; i < nb; i++) {
-            const block_q8_0 * restrict b_x0 = &vx0[i];
+            const block_q4_0 * restrict b_x0 = &vx0[i];
             const block_q8_0 * restrict b_y0 = &vy0[i];

-            const block_q8_0 * restrict b_x1 = &vx1[i];
+            const block_q4_0 * restrict b_x1 = &vx1[i];
             const block_q8_0 * restrict b_y1 = &vy1[i];

-            const int8x16_t x0_l = vld1q_s8(b_x0->qs);
-            const int8x16_t x0_h = vld1q_s8(b_x0->qs + 16);
-            const int8x16_t x1_l = vld1q_s8(b_x1->qs);
-            const int8x16_t x1_h = vld1q_s8(b_x1->qs + 16);
+            const int8x16_t x0_l = vld1q_s8((const int8_t*) b_x0->qs);
+            const int8x16_t x0_h = vld1q_s8((const int8_t*) b_x0->qs + 16);
+            const int8x16_t x1_l = vld1q_s8((const int8_t*) b_x1->qs);
+            const int8x16_t x1_h = vld1q_s8((const int8_t*) b_x1->qs + 16);

             // load y
             const int8x16_t y0_l = vld1q_s8(b_y0->qs);
             const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16);
             const int8x16_t y1_l = vld1q_s8(b_y1->qs);
             const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);

-            float32x4_t scale = {GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
-                                 GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
-                                 GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
-                                 GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};
+            float32_t _scale[4] = {GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
+                                   GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
+                                   GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
+                                   GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};
+            float32x4_t scale = vld1q_f32(_scale);

             int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
             int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
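Every hunk above applies the same pattern: braced NEON vector literals (float32x4_t scale = {...}), pointer arithmetic on const void * (vx + bx), and the + operator on NEON vector types are GCC/Clang extensions, so they are rewritten as a plain float32_t[4] array loaded with vld1q_f32, a byte offset taken through a const uint8_t * cast, and vaddq_f32, presumably so the code also builds with compilers that do not accept those extensions (e.g. MSVC on ARM64). The sketch below shows the same three substitutions in isolation; row_ptr and make_scale are illustrative names for this example only, not functions from ggml.c, and it assumes an AArch64 toolchain that provides arm_neon.h.

#include <arm_neon.h>
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>

// Portable replacement for `vx + bx` when vx is a `const void *`:
// advance by a byte stride through an explicit uint8_t cast.
static const void * row_ptr(const void * base, size_t byte_stride) {
    return (const uint8_t *) base + byte_stride;
}

// Portable replacement for a braced NEON vector literal:
// stage the four lane values in a plain array and load with vld1q_f32.
static float32x4_t make_scale(float a, float b, float c, float d) {
    const float32_t lanes[4] = { a, b, c, d };
    return vld1q_f32(lanes);
}

int main(void) {
    uint8_t blocks[64] = { 0 };
    const void * second_row = row_ptr(blocks, 32);   // same address as blocks + 32
    (void) second_row;

    float32x4_t acc   = vdupq_n_f32(1.0f);
    float32x4_t scale = make_scale(0.5f, 1.0f, 2.0f, 4.0f);

    // Portable replacement for `acc = acc + scale;`
    acc = vaddq_f32(acc, scale);

    printf("%f %f %f %f\n",
           vgetq_lane_f32(acc, 0), vgetq_lane_f32(acc, 1),
           vgetq_lane_f32(acc, 2), vgetq_lane_f32(acc, 3));
    return 0;
}

Functionally the two forms are equivalent; the temporary array is only a staging buffer for the lane values before the vld1q_f32 load.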