Skip to content

Commit 9268aac

Browse files
committed
metal : improve q4_K
Time per token drops from 28.3 ms to 26.0 ms by avoiding a branch in the calculation of the scales.
1 parent 17c10ac commit 9268aac

File tree

1 file changed

+30
-10
lines changed

1 file changed

+30
-10
lines changed

ggml-metal.metal

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -969,6 +969,10 @@ kernel void kernel_mul_mat_q4_k_f32(
969969
uint2 tpitg[[thread_position_in_threadgroup]],
970970
uint2 tptg[[threads_per_threadgroup]]) {
971971

972+
const uint16_t kmask1 = 0x3f3f;
973+
const uint16_t kmask2 = 0x0f0f;
974+
const uint16_t kmask3 = 0x0303;
975+
972976
const int nb = ne00/QK_K;
973977

974978
const int64_t r0 = tgpig.x;
@@ -983,29 +987,45 @@ kernel void kernel_mul_mat_q4_k_f32(
983987
const int tid = tpitg.y; // 0...16
984988
const int il = tid/4; // 0...3
985989
const int ir = tid%4; // 0...3
986-
const int n = 8;
987-
const int is = 2*il;
990+
const int n = 4;
991+
992+
const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
993+
const int in = il%2;
988994

989995
sum[ith] = 0.0f;
990996

997+
//uchar2 sc1, sc2;
998+
991999
float sumf = 0;
9921000
for (int i = tpitg.x; i < nb; i += tptg.x) {
9931001

994-
device const uint8_t * q = (x + i)->qs + 32*il + n*ir;
995-
device const float * y = yy + i*QK_K + 64*il + n*ir;
996-
device const uint8_t * scales = (x + i)->scales;
1002+
device const uint8_t * q1 = (x + i)->qs + 32*im + n*(2*ir + in);
1003+
device const float * y1 = yy + i*QK_K + 64*im + n*(2*ir + in);
1004+
device const uint8_t * q2 = q1 + 64;
1005+
device const float * y2 = y1 + 128;
1006+
1007+
device const uint16_t * a = (device const uint16_t *)(x + i)->scales;
9971008

9981009
const float dall = (float)((x + i)->d);
9991010
const float dmin = (float)((x + i)->dmin);
10001011

1001-
const uchar4 sc = get_scale_min_k4(is, scales);
1012+
const uchar2 sc1 = as_type<uchar2>((uint16_t)(a[im+0] & kmask1));
1013+
const uchar2 sc2 = as_type<uchar2>((uint16_t)(a[im+2] & kmask1));
1014+
const uchar2 sc3 = as_type<uchar2>((uint16_t)(((a[im+4] >> 0) & kmask2) | (((a[im+0] >> 6) & kmask3) << 4)));
1015+
const uchar2 sc4 = as_type<uchar2>((uint16_t)(((a[im+4] >> 4) & kmask2) | (((a[im+2] >> 6) & kmask3) << 4)));
10021016

1003-
float4 s = {0.f, 0.f, 0.f, 0.f};
1017+
float4 s1 = {0.f, 0.f, 0.f, 0.f};
1018+
float4 s2 = {0.f, 0.f, 0.f, 0.f};
10041019
for (int l = 0; l < n; ++l) {
1005-
s[0] += y[l+ 0] * (q[l] & 0xF); s[1] += y[l+ 0];
1006-
s[2] += y[l+32] * (q[l] >> 4); s[3] += y[l+32];
1020+
s1[0] += y1[l+ 0] * (q1[l] & 0xF); s1[1] += y1[l+ 0];
1021+
s1[2] += y1[l+32] * (q1[l] >> 4); s1[3] += y1[l+32];
1022+
s2[0] += y2[l+ 0] * (q2[l] & 0xF); s2[1] += y2[l+ 0];
1023+
s2[2] += y2[l+32] * (q2[l] >> 4); s2[3] += y2[l+32];
10071024
}
1008-
sumf += dall * (s[0] * sc[0] + s[2] * sc[2]) - dmin * (s[1] * sc[1] + s[3] * sc[3]);
1025+
sumf += dall * (s1[0] * sc1[0] + s1[2] * sc1[1]
1026+
+ s2[0] * sc3[0] + s2[2] * sc3[1])
1027+
- dmin * (s1[1] * sc2[0] + s1[3] * sc2[1]
1028+
+ s2[1] * sc4[0] + s2[3] * sc4[1]);
10091029

10101030
}
10111031
sum[ith] = sumf;

0 commit comments

Comments
 (0)