Skip to content

Commit 4dccb38

Browse files
authored
metal : improve dequantize precision to match CPU (#4836)
ggml-ci
1 parent 9a818f7 commit 4dccb38

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

ggml-metal.metal

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3841,8 +3841,8 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
38413841
uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
38423842
int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
38433843
: (scale_2&kmask2) | ((scale_1&kmask1) << 4);
3844-
half dl = il<8 ? d_all * (dl_int - 32.h) : d_all * (dl_int / 16.h - 32.h);
3845-
const half ml = 4.h * dl;
3844+
float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);
3845+
const float ml = 4.f * dl;
38463846

38473847
il = (il/2) & 3;
38483848
const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
@@ -3909,7 +3909,7 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
39093909
uint8_t ul = 1 << (il/2);
39103910
il = il & 3;
39113911
const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
3912-
const float d = il < 2 ? xb->d : xb->d / 16.h;
3912+
const float d = il < 2 ? xb->d : xb->d / 16.f;
39133913
const float min = xb->dmin;
39143914
const float dl = d * sc[0];
39153915
const float ml = min * sc[1];
@@ -3942,17 +3942,17 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
39423942
#if QK_K == 256
39433943
ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
39443944
qh = qh + 32*(il/8) + 16*(il&1);
3945-
half sc = scales[(il%2) + 2 * ((il/2))];
3945+
float sc = scales[(il%2) + 2 * ((il/2))];
39463946
il = (il/2) & 3;
39473947
#else
39483948
ql = ql + 16 * (il&1);
3949-
half sc = scales[il];
3949+
float sc = scales[il];
39503950
#endif
39513951
const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
39523952
const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
3953-
const half coef = il>1 ? 1.f/16.h : 1.h;
3954-
const half ml = d_all * sc * 32.h;
3955-
const half dl = d_all * sc * coef;
3953+
const float coef = il>1 ? 1.f/16.f : 1.f;
3954+
const float ml = d_all * sc * 32.f;
3955+
const float dl = d_all * sc * coef;
39563956
for (int i = 0; i < 16; ++i) {
39573957
const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
39583958
: ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));

0 commit comments

Comments
 (0)