@@ -3521,10 +3521,10 @@ void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg
3521
3521
3522
3522
template <typename type4x4>
3523
3523
void dequantize_q2_K (device const block_q2_K *xb, short il, thread type4x4 & reg) {
3524
- const half d = xb->d ;
3525
- const half min = xb->dmin ;
3524
+ const float d = xb->d ;
3525
+ const float min = xb->dmin ;
3526
3526
device const uint8_t * q = (device const uint8_t *)xb->qs ;
3527
- half dl, ml;
3527
+ float dl, ml;
3528
3528
uint8_t sc = xb->scales [il];
3529
3529
3530
3530
#if QK_K == 256
@@ -3594,10 +3594,10 @@ void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg
3594
3594
q = q + (il/4 ) * 32 + 16 * (il&1 );
3595
3595
il = il & 3 ;
3596
3596
const uchar2 sc = get_scale_min_k4_just2 (is, il/2 , xb->scales );
3597
- const half d = il < 2 ? xb->d : xb->d / 16 .h ;
3598
- const half min = xb->dmin ;
3599
- const half dl = d * sc[0 ];
3600
- const half ml = min * sc[1 ];
3597
+ const float d = il < 2 ? xb->d : xb->d / 16 .h ;
3598
+ const float min = xb->dmin ;
3599
+ const float dl = d * sc[0 ];
3600
+ const float ml = min * sc[1 ];
3601
3601
#else
3602
3602
q = q + 16 * (il&1 );
3603
3603
device const uint8_t * s = xb->scales ;
@@ -3624,13 +3624,13 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
3624
3624
uint8_t ul = 1 << (il/2 );
3625
3625
il = il & 3 ;
3626
3626
const uchar2 sc = get_scale_min_k4_just2 (is, il/2 , xb->scales );
3627
- const half d = il < 2 ? xb->d : xb->d / 16 .h ;
3628
- const half min = xb->dmin ;
3629
- const half dl = d * sc[0 ];
3630
- const half ml = min * sc[1 ];
3627
+ const float d = il < 2 ? xb->d : xb->d / 16 .h ;
3628
+ const float min = xb->dmin ;
3629
+ const float dl = d * sc[0 ];
3630
+ const float ml = min * sc[1 ];
3631
3631
3632
- const ushort mask = il<2 ? 0x0F : 0xF0 ;
3633
- const half qh_val = il<2 ? 16 .h : 256 .h ;
3632
+ const ushort mask = il<2 ? 0x0F : 0xF0 ;
3633
+ const float qh_val = il<2 ? 16 .f : 256 .f ;
3634
3634
for (int i = 0 ; i < 16 ; ++i) {
3635
3635
reg[i/4 ][i%4 ] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0 )) - ml;
3636
3636
}
0 commit comments