Skip to content

Commit 9766c0d

Browse files
matmul-int8: enable matmul-int8 with MSVC and fix Clang warnings
1 parent 91515cb commit 9766c0d

File tree

2 files changed

+40
-32
lines changed

2 files changed

+40
-32
lines changed

CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -994,6 +994,11 @@ if (CMAKE_OSX_ARCHITECTURES STREQUAL "arm64" OR CMAKE_GENERATOR_PLATFORM_LWR STR
994994
if (GGML_COMPILER_SUPPORT_DOTPROD)
995995
add_compile_definitions(__ARM_FEATURE_DOTPROD)
996996
endif ()
997+
# Probe the int8 matrix-multiply intrinsic itself (vmmlaq_s32: int32x4_t acc, int8x16_t a, int8x16_t b),
# so the result reflects whether the compiler can actually build the MATMUL_INT8 code paths.
check_cxx_source_compiles("#include <arm_neon.h>\nint main() { int8x16_t _a, _b; int32x4_t _s = vmmlaq_s32(_s, _a, _b); return 0; }" GGML_COMPILER_SUPPORT_MATMUL_INT8)
998+
if (GGML_COMPILER_SUPPORT_MATMUL_INT8)
999+
add_compile_definitions(__ARM_FEATURE_MATMUL_INT8)
1000+
endif ()
1001+
9971002
check_cxx_source_compiles("#include <arm_neon.h>\nint main() { float16_t _a; float16x8_t _s = vdupq_n_f16(_a); return 0; }" GGML_COMPILER_SUPPORT_FP16_VECTOR_ARITHMETIC)
9981003
if (GGML_COMPILER_SUPPORT_FP16_VECTOR_ARITHMETIC)
9991004
add_compile_definitions(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)

ggml-quants.c

Lines changed: 35 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -3409,10 +3409,9 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
34093409
#if defined(__ARM_FEATURE_MATMUL_INT8)
34103410
if (nrc == 2) {
34113411
const block_q4_0 * restrict vx0 = vx;
3412-
const block_q4_0 * restrict vx1 = vx + bx;
3413-
3412+
const block_q4_0 * restrict vx1 = (const block_q4_0 *) ((const uint8_t*)vx + bx);
34143413
const block_q8_0 * restrict vy0 = vy;
3415-
const block_q8_0 * restrict vy1 = vy + by;
3414+
const block_q8_0 * restrict vy1 = (const block_q8_0 *) ((const uint8_t*)vy + by);
34163415

34173416
float32x4_t sumv0 = vdupq_n_f32(0.0f);
34183417

@@ -3446,10 +3445,12 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
34463445
const int8x16_t y1_l = vld1q_s8(b_y1->qs);
34473446
const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
34483447

3449-
float32x4_t scale = {GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
3450-
GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
3451-
GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
3452-
GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};
3448+
float32_t _scale[4] = { GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
3449+
GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
3450+
GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
3451+
GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};
3452+
3453+
float32x4_t scale = vld1q_f32(_scale);
34533454

34543455
int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
34553456
int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
@@ -3776,9 +3777,9 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
37763777
#if defined(__ARM_FEATURE_MATMUL_INT8)
37773778
if (nrc == 2) {
37783779
const block_q4_1 * restrict vx0 = vx;
3779-
const block_q4_1 * restrict vx1 = vx + bx;
3780+
const block_q4_1 * restrict vx1 = (const block_q4_1 *) ((const uint8_t*)vx + bx);
37803781
const block_q8_1 * restrict vy0 = vy;
3781-
const block_q8_1 * restrict vy1 = vy + by;
3782+
const block_q8_1 * restrict vy1 = (const block_q8_1 *) ((const uint8_t*)vy + by);
37823783

37833784
float32x4_t sumv0 = vdupq_n_f32(0.0f);
37843785
float32x4_t summs0 = vdupq_n_f32(0.0f);
@@ -3789,11 +3790,11 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
37893790
const block_q8_1 * restrict b_y0 = &vy0[i];
37903791
const block_q8_1 * restrict b_y1 = &vy1[i];
37913792

3792-
float32x4_t summs_t = {GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y0->s),
3793-
GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y0->s),
3794-
GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y1->s),
3795-
GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y1->s)};
3796-
summs0 += summs_t;
3793+
float32_t summs_t[4] = {GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y0->s),
3794+
GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y0->s),
3795+
GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y1->s),
3796+
GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y1->s)};
3797+
summs0 = vaddq_f32(summs0, vld1q_f32(summs_t));
37973798

37983799
const uint8x16_t m4b = vdupq_n_u8(0x0F);
37993800

@@ -3813,10 +3814,11 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
38133814
const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
38143815

38153816
// mmla into int32x4_t
3816-
float32x4_t scale = {GGML_FP16_TO_FP32(b_x0->d)*b_y0->d,
3817-
GGML_FP16_TO_FP32(b_x0->d)*b_y1->d,
3818-
GGML_FP16_TO_FP32(b_x1->d)*b_y0->d,
3819-
GGML_FP16_TO_FP32(b_x1->d)*b_y1->d};
3817+
float32_t _scale[4] = {GGML_FP16_TO_FP32(b_x0->d)*b_y0->d,
3818+
GGML_FP16_TO_FP32(b_x0->d)*b_y1->d,
3819+
GGML_FP16_TO_FP32(b_x1->d)*b_y0->d,
3820+
GGML_FP16_TO_FP32(b_x1->d)*b_y1->d};
3821+
float32x4_t scale = vld1q_f32(_scale);
38203822

38213823
int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
38223824
int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
@@ -3835,7 +3837,7 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
38353837

38363838
float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2);
38373839
float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
3838-
sumv2 = sumv2 + summs0;
3840+
sumv2 = vaddq_f32(sumv2, summs0);
38393841

38403842
vst1_f32(s, vget_low_f32(sumv2));
38413843
vst1_f32(s + bs, vget_high_f32(sumv2));
@@ -4588,35 +4590,36 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
45884590

45894591
#if defined(__ARM_FEATURE_MATMUL_INT8)
45904592
if (nrc == 2) {
4591-
const block_q8_0 * restrict vx0 = vx;
4592-
const block_q8_0 * restrict vx1 = vx + bx;
4593+
const block_q4_0 * restrict vx0 = vx;
4594+
const block_q4_0 * restrict vx1 = (const block_q4_0 *) ((const uint8_t*)vx + bx);
45934595
const block_q8_0 * restrict vy0 = vy;
4594-
const block_q8_0 * restrict vy1 = vy + by;
4596+
const block_q8_0 * restrict vy1 = (const block_q8_0 *) ((const uint8_t*)vy + by);
45954597

45964598
float32x4_t sumv0 = vdupq_n_f32(0.0f);
45974599

45984600
for (int i = 0; i < nb; i++) {
4599-
const block_q8_0 * restrict b_x0 = &vx0[i];
4601+
const block_q4_0 * restrict b_x0 = &vx0[i];
46004602
const block_q8_0 * restrict b_y0 = &vy0[i];
46014603

4602-
const block_q8_0 * restrict b_x1 = &vx1[i];
4604+
const block_q4_0 * restrict b_x1 = &vx1[i];
46034605
const block_q8_0 * restrict b_y1 = &vy1[i];
46044606

4605-
const int8x16_t x0_l = vld1q_s8(b_x0->qs);
4606-
const int8x16_t x0_h = vld1q_s8(b_x0->qs + 16);
4607-
const int8x16_t x1_l = vld1q_s8(b_x1->qs);
4608-
const int8x16_t x1_h = vld1q_s8(b_x1->qs + 16);
4607+
const int8x16_t x0_l = vld1q_s8((const int8_t*)b_x0->qs);
4608+
const int8x16_t x0_h = vld1q_s8((const int8_t*)b_x0->qs + 16);
4609+
const int8x16_t x1_l = vld1q_s8((const int8_t*)b_x1->qs);
4610+
const int8x16_t x1_h = vld1q_s8((const int8_t*)b_x1->qs + 16);
46094611

46104612
// load y
46114613
const int8x16_t y0_l = vld1q_s8(b_y0->qs);
46124614
const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16);
46134615
const int8x16_t y1_l = vld1q_s8(b_y1->qs);
46144616
const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
46154617

4616-
float32x4_t scale = {GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
4617-
GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
4618-
GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
4619-
GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};
4618+
float32_t _scale[4] = {GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
4619+
GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
4620+
GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
4621+
GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};
4622+
float32x4_t scale = vld1q_f32(_scale);
46204623

46214624
int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
46224625
int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));

0 commit comments

Comments
 (0)