@@ -971,7 +971,7 @@ kernel void kernel_mul_mat_q4_k_f32(
971
971
972
972
const uint16_t kmask1 = 0x3f3f ;
973
973
const uint16_t kmask2 = 0x0f0f ;
974
- const uint16_t kmask3 = 0x0303 ;
974
+ const uint16_t kmask3 = 0xc0c0 ;
975
975
976
976
const int nb = ne00/QK_K;
977
977
@@ -991,16 +991,15 @@ kernel void kernel_mul_mat_q4_k_f32(
991
991
992
992
const int im = il/2 ; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
993
993
const int in = il%2 ;
994
+ const int l0 = n*(2 *ir + in);
994
995
995
996
sum[ith] = 0 .0f ;
996
997
997
- // uchar2 sc1, sc2;
998
-
999
998
float sumf = 0 ;
1000
999
for (int i = tpitg.x ; i < nb; i += tptg.x ) {
1001
1000
1002
- device const uint8_t * q1 = (x + i)->qs + 32 *im + n*( 2 *ir + in) ;
1003
- device const float * y1 = yy + i*QK_K + 64 *im + n*( 2 *ir + in) ;
1001
+ device const uint8_t * q1 = (x + i)->qs + 32 *im + l0 ;
1002
+ device const float * y1 = yy + i*QK_K + 64 *im + l0 ;
1004
1003
device const uint8_t * q2 = q1 + 64 ;
1005
1004
device const float * y2 = y1 + 128 ;
1006
1005
@@ -1011,21 +1010,16 @@ kernel void kernel_mul_mat_q4_k_f32(
1011
1010
1012
1011
const uchar2 sc1 = as_type<uchar2>((uint16_t )(a[im+0 ] & kmask1));
1013
1012
const uchar2 sc2 = as_type<uchar2>((uint16_t )(a[im+2 ] & kmask1));
1014
- const uchar2 sc3 = as_type<uchar2>((uint16_t )(((a[im+4 ] >> 0 ) & kmask2) | ((( a[im+0 ] >> 6 ) & kmask3) << 4 )));
1015
- const uchar2 sc4 = as_type<uchar2>((uint16_t )(((a[im+4 ] >> 4 ) & kmask2) | ((( a[im+2 ] >> 6 ) & kmask3) << 4 )));
1013
+ const uchar2 sc3 = as_type<uchar2>((uint16_t )(((a[im+4 ] >> 0 ) & kmask2) | ((a[im+0 ] & kmask3) >> 2 )));
1014
+ const uchar2 sc4 = as_type<uchar2>((uint16_t )(((a[im+4 ] >> 4 ) & kmask2) | ((a[im+2 ] & kmask3) >> 2 )));
1016
1015
1017
- float4 s1 = {0 .f , 0 .f , 0 .f , 0 .f };
1018
- float4 s2 = {0 .f , 0 .f , 0 .f , 0 .f };
1016
+ float2 s = {0 .f , 0 .f };
1019
1017
for (int l = 0 ; l < n; ++l) {
1020
- s1[0 ] += y1[l+ 0 ] * (q1[l] & 0xF ); s1[1 ] += y1[l+ 0 ];
1021
- s1[2 ] += y1[l+32 ] * (q1[l] >> 4 ); s1[3 ] += y1[l+32 ];
1022
- s2[0 ] += y2[l+ 0 ] * (q2[l] & 0xF ); s2[1 ] += y2[l+ 0 ];
1023
- s2[2 ] += y2[l+32 ] * (q2[l] >> 4 ); s2[3 ] += y2[l+32 ];
1018
+ s[0 ] += y1[l] * sc1[0 ] * (q1[l] & 0xF ) + y1[l+32 ] * sc1[1 ] * (q1[l] >> 4 )
1019
+ + y2[l] * sc3[0 ] * (q2[l] & 0xF ) + y2[l+32 ] * sc3[1 ] * (q2[l] >> 4 );
1020
+ s[1 ] += y1[l] * sc2[0 ] + y1[l+32 ] * sc2[1 ] + y2[l] * sc4[0 ] + y2[l+32 ] * sc4[1 ];
1024
1021
}
1025
- sumf += dall * (s1[0 ] * sc1[0 ] + s1[2 ] * sc1[1 ]
1026
- + s2[0 ] * sc3[0 ] + s2[2 ] * sc3[1 ])
1027
- - dmin * (s1[1 ] * sc2[0 ] + s1[3 ] * sc2[1 ]
1028
- + s2[1 ] * sc4[0 ] + s2[3 ] * sc4[1 ]);
1022
+ sumf += dall * s[0 ] - dmin * s[1 ];
1029
1023
1030
1024
}
1031
1025
sum[ith] = sumf;
0 commit comments