Skip to content

Commit 74e0c9b

Browse files
committed
metal : fix accuracy of dequantization kernels
1 parent bc01448 commit 74e0c9b

File tree

2 files changed

+15
-14
lines changed

2 files changed

+15
-14
lines changed

ggml-metal.metal

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3521,10 +3521,10 @@ void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg
35213521

35223522
template <typename type4x4>
35233523
void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
3524-
const half d = xb->d;
3525-
const half min = xb->dmin;
3524+
const float d = xb->d;
3525+
const float min = xb->dmin;
35263526
device const uint8_t * q = (device const uint8_t *)xb->qs;
3527-
half dl, ml;
3527+
float dl, ml;
35283528
uint8_t sc = xb->scales[il];
35293529

35303530
#if QK_K == 256
@@ -3594,10 +3594,10 @@ void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg
35943594
q = q + (il/4) * 32 + 16 * (il&1);
35953595
il = il & 3;
35963596
const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
3597-
const half d = il < 2 ? xb->d : xb->d / 16.h;
3598-
const half min = xb->dmin;
3599-
const half dl = d * sc[0];
3600-
const half ml = min * sc[1];
3597+
const float d = il < 2 ? xb->d : xb->d / 16.h;
3598+
const float min = xb->dmin;
3599+
const float dl = d * sc[0];
3600+
const float ml = min * sc[1];
36013601
#else
36023602
q = q + 16 * (il&1);
36033603
device const uint8_t * s = xb->scales;
@@ -3624,13 +3624,13 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
36243624
uint8_t ul = 1 << (il/2);
36253625
il = il & 3;
36263626
const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
3627-
const half d = il < 2 ? xb->d : xb->d / 16.h;
3628-
const half min = xb->dmin;
3629-
const half dl = d * sc[0];
3630-
const half ml = min * sc[1];
3627+
const float d = il < 2 ? xb->d : xb->d / 16.h;
3628+
const float min = xb->dmin;
3629+
const float dl = d * sc[0];
3630+
const float ml = min * sc[1];
36313631

3632-
const ushort mask = il<2 ? 0x0F : 0xF0;
3633-
const half qh_val = il<2 ? 16.h : 256.h;
3632+
const ushort mask = il<2 ? 0x0F : 0xF0;
3633+
const float qh_val = il<2 ? 16.f : 256.f;
36343634
for (int i = 0; i < 16; ++i) {
36353635
reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
36363636
}

tests/test-backend-ops.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -432,9 +432,10 @@ struct test_case {
432432
if (err > ud->max_err) {
433433
printf("[%s] NMSE = %f ", ggml_op_desc(t1), err);
434434
//for (int i = 0; i < f1.size(); i++) {
435-
// printf("(%f, %f) ", f1[i], f2[i]);
435+
// printf("%5d %9.6f %9.6f, diff = %9.6f\n", i, f1[i], f2[i], f1[i] - f2[i]);
436436
//}
437437
//printf("\n");
438+
//exit(1);
438439
ud->ok = false;
439440
}
440441
return true;

0 commit comments

Comments
 (0)