Skip to content

Commit 95ec7f0

Browse files
committed
metal : small improvement for Q4_K
1 parent 9268aac commit 95ec7f0

File tree

1 file changed

+11
-17
lines changed

1 file changed

+11
-17
lines changed

ggml-metal.metal

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -971,7 +971,7 @@ kernel void kernel_mul_mat_q4_k_f32(
971971

972972
const uint16_t kmask1 = 0x3f3f;
973973
const uint16_t kmask2 = 0x0f0f;
974-
const uint16_t kmask3 = 0x0303;
974+
const uint16_t kmask3 = 0xc0c0;
975975

976976
const int nb = ne00/QK_K;
977977

@@ -991,16 +991,15 @@ kernel void kernel_mul_mat_q4_k_f32(
991991

992992
const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
993993
const int in = il%2;
994+
const int l0 = n*(2*ir + in);
994995

995996
sum[ith] = 0.0f;
996997

997-
//uchar2 sc1, sc2;
998-
999998
float sumf = 0;
1000999
for (int i = tpitg.x; i < nb; i += tptg.x) {
10011000

1002-
device const uint8_t * q1 = (x + i)->qs + 32*im + n*(2*ir + in);
1003-
device const float * y1 = yy + i*QK_K + 64*im + n*(2*ir + in);
1001+
device const uint8_t * q1 = (x + i)->qs + 32*im + l0;
1002+
device const float * y1 = yy + i*QK_K + 64*im + l0;
10041003
device const uint8_t * q2 = q1 + 64;
10051004
device const float * y2 = y1 + 128;
10061005

@@ -1011,21 +1010,16 @@ kernel void kernel_mul_mat_q4_k_f32(
10111010

10121011
const uchar2 sc1 = as_type<uchar2>((uint16_t)(a[im+0] & kmask1));
10131012
const uchar2 sc2 = as_type<uchar2>((uint16_t)(a[im+2] & kmask1));
1014-
const uchar2 sc3 = as_type<uchar2>((uint16_t)(((a[im+4] >> 0) & kmask2) | (((a[im+0] >> 6) & kmask3) << 4)));
1015-
const uchar2 sc4 = as_type<uchar2>((uint16_t)(((a[im+4] >> 4) & kmask2) | (((a[im+2] >> 6) & kmask3) << 4)));
1013+
const uchar2 sc3 = as_type<uchar2>((uint16_t)(((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2)));
1014+
const uchar2 sc4 = as_type<uchar2>((uint16_t)(((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2)));
10161015

1017-
float4 s1 = {0.f, 0.f, 0.f, 0.f};
1018-
float4 s2 = {0.f, 0.f, 0.f, 0.f};
1016+
float2 s = {0.f, 0.f};
10191017
for (int l = 0; l < n; ++l) {
1020-
s1[0] += y1[l+ 0] * (q1[l] & 0xF); s1[1] += y1[l+ 0];
1021-
s1[2] += y1[l+32] * (q1[l] >> 4); s1[3] += y1[l+32];
1022-
s2[0] += y2[l+ 0] * (q2[l] & 0xF); s2[1] += y2[l+ 0];
1023-
s2[2] += y2[l+32] * (q2[l] >> 4); s2[3] += y2[l+32];
1018+
s[0] += y1[l] * sc1[0] * (q1[l] & 0xF) + y1[l+32] * sc1[1] * (q1[l] >> 4)
1019+
+ y2[l] * sc3[0] * (q2[l] & 0xF) + y2[l+32] * sc3[1] * (q2[l] >> 4);
1020+
s[1] += y1[l] * sc2[0] + y1[l+32] * sc2[1] + y2[l] * sc4[0] + y2[l+32] * sc4[1];
10241021
}
1025-
sumf += dall * (s1[0] * sc1[0] + s1[2] * sc1[1]
1026-
+ s2[0] * sc3[0] + s2[2] * sc3[1])
1027-
- dmin * (s1[1] * sc2[0] + s1[3] * sc2[1]
1028-
+ s2[1] * sc4[0] + s2[3] * sc4[1]);
1022+
sumf += dall * s[0] - dmin * s[1];
10291023

10301024
}
10311025
sum[ith] = sumf;

0 commit comments

Comments
 (0)