File tree Expand file tree Collapse file tree 1 file changed +6
-6
lines changed Expand file tree Collapse file tree 1 file changed +6
-6
lines changed Original file line number Diff line number Diff line change @@ -1416,14 +1416,14 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1(
1416
1416
for (int i = 0 ; i < 4 ; ++i) {
1417
1417
const int sc = bq2_K->scales [iqs - iqs%8 + (iqs%8 ) / 4 + 2 *i];
1418
1418
1419
- const int vii = (vi >> (2 *i)) & 0x03030303 ;
1420
-
1421
1419
const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
1422
- const float d8 = bq8i->d ;
1423
- const int qs8 = *((int *) &bq8i->qs [4 *(iqs%8 )]);
1420
+ const float d8i = bq8i->d ;
1421
+
1422
+ const int vii = (vi >> (2 *i)) & 0x03030303 ;
1423
+ const int uii = *((int *) &bq8i->qs [4 * (iqs%8 )]);
1424
1424
1425
- sumf_d += d8 * __dp4a (vii, qs8 , 0 ) * (sc & 0xF );
1426
- sumf_m += d8 * __dp4a (0x01010101 , qs8 , 0 ) * (sc >> 4 );
1425
+ sumf_d += d8i * __dp4a (vii, uii , 0 ) * (sc & 0xF );
1426
+ sumf_m += d8i * __dp4a (0x01010101 , uii , 0 ) * (sc >> 4 );
1427
1427
}
1428
1428
1429
1429
You can’t perform that action at this time.
0 commit comments