@@ -969,6 +969,10 @@ kernel void kernel_mul_mat_q4_k_f32(
969
969
uint2 tpitg[[thread_position_in_threadgroup]],
970
970
uint2 tptg[[threads_per_threadgroup]]) {
971
971
972
+ const uint16_t kmask1 = 0x3f3f ;
973
+ const uint16_t kmask2 = 0x0f0f ;
974
+ const uint16_t kmask3 = 0x0303 ;
975
+
972
976
const int nb = ne00/QK_K;
973
977
974
978
const int64_t r0 = tgpig.x ;
@@ -983,29 +987,45 @@ kernel void kernel_mul_mat_q4_k_f32(
983
987
const int tid = tpitg.y ; // 0...16
984
988
const int il = tid/4 ; // 0...3
985
989
const int ir = tid%4 ; // 0...3
986
- const int n = 8 ;
987
- const int is = 2 *il;
990
+ const int n = 4 ;
991
+
992
+ const int im = il/2 ; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
993
+ const int in = il%2 ;
988
994
989
995
sum[ith] = 0 .0f ;
990
996
997
+ // uchar2 sc1, sc2;
998
+
991
999
float sumf = 0 ;
992
1000
for (int i = tpitg.x ; i < nb; i += tptg.x ) {
993
1001
994
- device const uint8_t * q = (x + i)->qs + 32 *il + n*ir;
995
- device const float * y = yy + i*QK_K + 64 *il + n*ir;
996
- device const uint8_t * scales = (x + i)->scales ;
1002
+ device const uint8_t * q1 = (x + i)->qs + 32 *im + n*(2 *ir + in);
1003
+ device const float * y1 = yy + i*QK_K + 64 *im + n*(2 *ir + in);
1004
+ device const uint8_t * q2 = q1 + 64 ;
1005
+ device const float * y2 = y1 + 128 ;
1006
+
1007
+ device const uint16_t * a = (device const uint16_t *)(x + i)->scales ;
997
1008
998
1009
const float dall = (float )((x + i)->d );
999
1010
const float dmin = (float )((x + i)->dmin );
1000
1011
1001
- const uchar4 sc = get_scale_min_k4 (is, scales);
1012
+ const uchar2 sc1 = as_type<uchar2>((uint16_t )(a[im+0 ] & kmask1));
1013
+ const uchar2 sc2 = as_type<uchar2>((uint16_t )(a[im+2 ] & kmask1));
1014
+ const uchar2 sc3 = as_type<uchar2>((uint16_t )(((a[im+4 ] >> 0 ) & kmask2) | (((a[im+0 ] >> 6 ) & kmask3) << 4 )));
1015
+ const uchar2 sc4 = as_type<uchar2>((uint16_t )(((a[im+4 ] >> 4 ) & kmask2) | (((a[im+2 ] >> 6 ) & kmask3) << 4 )));
1002
1016
1003
- float4 s = {0 .f , 0 .f , 0 .f , 0 .f };
1017
+ float4 s1 = {0 .f , 0 .f , 0 .f , 0 .f };
1018
+ float4 s2 = {0 .f , 0 .f , 0 .f , 0 .f };
1004
1019
for (int l = 0 ; l < n; ++l) {
1005
- s[0 ] += y[l+ 0 ] * (q[l] & 0xF ); s[1 ] += y[l+ 0 ];
1006
- s[2 ] += y[l+32 ] * (q[l] >> 4 ); s[3 ] += y[l+32 ];
1020
+ s1[0 ] += y1[l+ 0 ] * (q1[l] & 0xF ); s1[1 ] += y1[l+ 0 ];
1021
+ s1[2 ] += y1[l+32 ] * (q1[l] >> 4 ); s1[3 ] += y1[l+32 ];
1022
+ s2[0 ] += y2[l+ 0 ] * (q2[l] & 0xF ); s2[1 ] += y2[l+ 0 ];
1023
+ s2[2 ] += y2[l+32 ] * (q2[l] >> 4 ); s2[3 ] += y2[l+32 ];
1007
1024
}
1008
- sumf += dall * (s[0 ] * sc[0 ] + s[2 ] * sc[2 ]) - dmin * (s[1 ] * sc[1 ] + s[3 ] * sc[3 ]);
1025
+ sumf += dall * (s1[0 ] * sc1[0 ] + s1[2 ] * sc1[1 ]
1026
+ + s2[0 ] * sc3[0 ] + s2[2 ] * sc3[1 ])
1027
+ - dmin * (s1[1 ] * sc2[0 ] + s1[3 ] * sc2[1 ]
1028
+ + s2[1 ] * sc4[0 ] + s2[3 ] * sc4[1 ]);
1009
1029
1010
1030
}
1011
1031
sum[ith] = sumf;
0 commit comments