Skip to content

Commit 81f1802

Browse files
matmul-int8: enable matmul-int8 with MSVC and fix Clang warnings
1 parent 8a0b854 commit 81f1802

File tree

2 files changed

+40
-32
lines changed

2 files changed

+40
-32
lines changed

CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -995,6 +995,11 @@ if (CMAKE_OSX_ARCHITECTURES STREQUAL "arm64" OR CMAKE_GENERATOR_PLATFORM_LWR STR
995995
if (GGML_COMPILER_SUPPORT_DOTPROD)
996996
add_compile_definitions(__ARM_FEATURE_DOTPROD)
997997
endif ()
998+
check_cxx_source_compiles("#include <arm_neon.h>\nint main() { int8x16_t _a, _b; int32x4_t _s = vmmlaq_s32(_s, _a, _b); return 0; }" GGML_COMPILER_SUPPORT_MATMUL_INT8) # probe the real int8 matrix-multiply intrinsic (vmmlaq_s32: int32x4_t acc, int8x16_t a, int8x16_t b) so the result reflects i8mm support
999+
if (GGML_COMPILER_SUPPORT_MATMUL_INT8)
1000+
add_compile_definitions(__ARM_FEATURE_MATMUL_INT8)
1001+
endif ()
1002+
9981003
check_cxx_source_compiles("#include <arm_neon.h>\nint main() { float16_t _a; float16x8_t _s = vdupq_n_f16(_a); return 0; }" GGML_COMPILER_SUPPORT_FP16_VECTOR_ARITHMETIC)
9991004
if (GGML_COMPILER_SUPPORT_FP16_VECTOR_ARITHMETIC)
10001005
add_compile_definitions(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)

ggml-quants.c

Lines changed: 35 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -3415,10 +3415,9 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
34153415
#if defined(__ARM_FEATURE_MATMUL_INT8)
34163416
if (nrc == 2) {
34173417
const block_q4_0 * restrict vx0 = vx;
3418-
const block_q4_0 * restrict vx1 = vx + bx;
3419-
3418+
const block_q4_0 * restrict vx1 = (const block_q4_0 *) ((const uint8_t*)vx + bx);
34203419
const block_q8_0 * restrict vy0 = vy;
3421-
const block_q8_0 * restrict vy1 = vy + by;
3420+
const block_q8_0 * restrict vy1 = (const block_q8_0 *) ((const uint8_t*)vy + by);
34223421

34233422
float32x4_t sumv0 = vdupq_n_f32(0.0f);
34243423

@@ -3452,10 +3451,12 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
34523451
const int8x16_t y1_l = vld1q_s8(b_y1->qs);
34533452
const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
34543453

3455-
float32x4_t scale = {GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
3456-
GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
3457-
GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
3458-
GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};
3454+
float32_t _scale[4] = { GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
3455+
GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
3456+
GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
3457+
GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};
3458+
3459+
float32x4_t scale = vld1q_f32(_scale);
34593460

34603461
int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
34613462
int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
@@ -3782,9 +3783,9 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
37823783
#if defined(__ARM_FEATURE_MATMUL_INT8)
37833784
if (nrc == 2) {
37843785
const block_q4_1 * restrict vx0 = vx;
3785-
const block_q4_1 * restrict vx1 = vx + bx;
3786+
const block_q4_1 * restrict vx1 = (const block_q4_1 *) ((const uint8_t*)vx + bx);
37863787
const block_q8_1 * restrict vy0 = vy;
3787-
const block_q8_1 * restrict vy1 = vy + by;
3788+
const block_q8_1 * restrict vy1 = (const block_q8_1 *) ((const uint8_t*)vy + by);
37883789

37893790
float32x4_t sumv0 = vdupq_n_f32(0.0f);
37903791
float32x4_t summs0 = vdupq_n_f32(0.0f);
@@ -3795,11 +3796,11 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
37953796
const block_q8_1 * restrict b_y0 = &vy0[i];
37963797
const block_q8_1 * restrict b_y1 = &vy1[i];
37973798

3798-
float32x4_t summs_t = {GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y0->s),
3799-
GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y0->s),
3800-
GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y1->s),
3801-
GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y1->s)};
3802-
summs0 += summs_t;
3799+
float32_t summs_t[4] = {GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y0->s),
3800+
GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y0->s),
3801+
GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y1->s),
3802+
GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y1->s)};
3803+
summs0 = vaddq_f32(summs0, vld1q_f32(summs_t));
38033804

38043805
const uint8x16_t m4b = vdupq_n_u8(0x0F);
38053806

@@ -3819,10 +3820,11 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
38193820
const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
38203821

38213822
// mmla into int32x4_t
3822-
float32x4_t scale = {GGML_FP16_TO_FP32(b_x0->d)*b_y0->d,
3823-
GGML_FP16_TO_FP32(b_x0->d)*b_y1->d,
3824-
GGML_FP16_TO_FP32(b_x1->d)*b_y0->d,
3825-
GGML_FP16_TO_FP32(b_x1->d)*b_y1->d};
3823+
float32_t _scale[4] = {GGML_FP16_TO_FP32(b_x0->d)*b_y0->d,
3824+
GGML_FP16_TO_FP32(b_x0->d)*b_y1->d,
3825+
GGML_FP16_TO_FP32(b_x1->d)*b_y0->d,
3826+
GGML_FP16_TO_FP32(b_x1->d)*b_y1->d};
3827+
float32x4_t scale = vld1q_f32(_scale);
38263828

38273829
int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
38283830
int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
@@ -3841,7 +3843,7 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
38413843

38423844
float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2);
38433845
float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
3844-
sumv2 = sumv2 + summs0;
3846+
sumv2 = vaddq_f32(sumv2, summs0);
38453847

38463848
vst1_f32(s, vget_low_f32(sumv2));
38473849
vst1_f32(s + bs, vget_high_f32(sumv2));
@@ -4594,35 +4596,36 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
45944596

45954597
#if defined(__ARM_FEATURE_MATMUL_INT8)
45964598
if (nrc == 2) {
4597-
const block_q8_0 * restrict vx0 = vx;
4598-
const block_q8_0 * restrict vx1 = vx + bx;
4599+
const block_q4_0 * restrict vx0 = vx;
4600+
const block_q4_0 * restrict vx1 = (const block_q4_0 *) ((const uint8_t*)vx + bx);
45994601
const block_q8_0 * restrict vy0 = vy;
4600-
const block_q8_0 * restrict vy1 = vy + by;
4602+
const block_q8_0 * restrict vy1 = (const block_q8_0 *) ((const uint8_t*)vy + by);
46014603

46024604
float32x4_t sumv0 = vdupq_n_f32(0.0f);
46034605

46044606
for (int i = 0; i < nb; i++) {
4605-
const block_q8_0 * restrict b_x0 = &vx0[i];
4607+
const block_q4_0 * restrict b_x0 = &vx0[i];
46064608
const block_q8_0 * restrict b_y0 = &vy0[i];
46074609

4608-
const block_q8_0 * restrict b_x1 = &vx1[i];
4610+
const block_q4_0 * restrict b_x1 = &vx1[i];
46094611
const block_q8_0 * restrict b_y1 = &vy1[i];
46104612

4611-
const int8x16_t x0_l = vld1q_s8(b_x0->qs);
4612-
const int8x16_t x0_h = vld1q_s8(b_x0->qs + 16);
4613-
const int8x16_t x1_l = vld1q_s8(b_x1->qs);
4614-
const int8x16_t x1_h = vld1q_s8(b_x1->qs + 16);
4613+
const int8x16_t x0_l = vld1q_s8((const int8_t*)b_x0->qs);
4614+
const int8x16_t x0_h = vld1q_s8((const int8_t*)b_x0->qs + 16);
4615+
const int8x16_t x1_l = vld1q_s8((const int8_t*)b_x1->qs);
4616+
const int8x16_t x1_h = vld1q_s8((const int8_t*)b_x1->qs + 16);
46154617

46164618
// load y
46174619
const int8x16_t y0_l = vld1q_s8(b_y0->qs);
46184620
const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16);
46194621
const int8x16_t y1_l = vld1q_s8(b_y1->qs);
46204622
const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
46214623

4622-
float32x4_t scale = {GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
4623-
GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
4624-
GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
4625-
GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};
4624+
float32_t _scale[4] = {GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
4625+
GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
4626+
GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
4627+
GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};
4628+
float32x4_t scale = vld1q_f32(_scale);
46264629

46274630
int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
46284631
int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));

0 commit comments

Comments
 (0)