// Partial dot product between 8 q5_0 quantized values and 8 q8_1 values,
// using the dp4a integer SIMD intrinsic. Requires __CUDA_ARCH__ >= MIN_CC_DP4A.
//
// qs:   packed 4-bit low quant bits; low nibbles feed vi0, high nibbles vi1
// qh:   the 5th (high) quant bits, pre-shifted by the caller so that bits
//       0..3 belong to vi0 and bits 16..19 to vi1
// ui0:  4 packed int8 values from the q8_1 block (first half)
// ui1:  4 packed int8 values from the q8_1 block (second half)
// d5:   q5_0 block scale
// ds8:  q8_1 block {d8, s8} pair; ds8.y is used to fold in the constant -16
//       offset of q5_0 values without a per-value __vsub4
//       (NOTE(review): presumably ds8.y == d8 * sum(q8) — confirm against the
//       q8_1 quantization code)
static __device__ __forceinline__ float vec_dot_q5_0_q8_1_impl(
    const int & qs, const int & qh, const int & ui0, const int & ui1, const half & d5, const half2 & ds8) {

#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
    int vi0 = (qs >>  0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh as 5th bits
    vi0    |= (qh <<  4) & 0x00000010; // 0 ->  4
    vi0    |= (qh << 11) & 0x00001000; // 1 -> 12
    vi0    |= (qh << 18) & 0x00100000; // 2 -> 20
    vi0    |= (qh << 25) & 0x10000000; // 3 -> 28

    int sumi = __dp4a(vi0, ui0, 0); // SIMD dot product of quantized values

    int vi1 = (qs >>  4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits
    vi1    |= (qh >> 12) & 0x00000010; // 16 ->  4
    vi1    |= (qh >>  5) & 0x00001000; // 17 -> 12
    vi1    |= (qh <<  2) & 0x00100000; // 18 -> 20
    vi1    |= (qh <<  9) & 0x10000000; // 19 -> 28

    sumi = __dp4a(vi1, ui1, sumi); // SIMD dot product of quantized values

    // d5 * (sumi*d8 - 16*sum(u)*d8): the -16 offset of the q5_0 values is
    // applied via ds8.y, split evenly over the QI5_0 calls that together
    // consume one q8_1 block (16/QI5_0 per call).
    return __half2float(d5) * (sumi*__half2float(ds8.x) - (16/QI5_0) * __half2float(ds8.y));
#else
    return 0.0f; // only to satisfy the compiler
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}
// Dot product of one thread's slice of a q5_0 block with the matching q8_1
// block, delegating the integer math to vec_dot_q5_0_q8_1_impl.
// NOTE(review): the parameter list is reconstructed from the hunk header and
// from how vbq/bq8_1/iqs are used in the body — confirm against the original
// declaration.
static __device__ __forceinline__ float vec_dot_q5_0_q8_1(
    const void * vbq, const block_q8_1 * bq8_1, const int & iqs) {

    const block_q5_0 * bq5_0 = (const block_q5_0 *) vbq;

    const int qs  = get_int_from_uint8(bq5_0->qs, iqs);
    // All high (5th) bits are read as one 32-bit word; shifting by 4*iqs puts
    // the bits for this slice where vec_dot_q5_0_q8_1_impl expects them
    // (bits 0..3 for the low nibbles, 16..19 for the high nibbles).
    const int qh  = get_int_from_uint8(bq5_0->qh, 0) >> (4 * iqs);

    const int ui0 = get_int_from_int8_aligned(bq8_1->qs, iqs);
    const int ui1 = get_int_from_int8_aligned(bq8_1->qs, iqs + QI5_0);

    return vec_dot_q5_0_q8_1_impl(qs, qh, ui0, ui1, bq5_0->d, bq8_1->ds);
}
static __device__ __forceinline__ float vec_dot_q5_1_q8_1 (
0 commit comments