Skip to content

Commit 699b1ad

Browse files
authored
opencl : fix kernels for the new formats (#1422)
* Fix OpenCL kernels for the new formats * Fix Q5_0 alignment issues.
1 parent fb62f92 commit 699b1ad

File tree

1 file changed

+90
-99
lines changed

1 file changed

+90
-99
lines changed

ggml-opencl.c

Lines changed: 90 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -12,109 +12,129 @@
1212
#define MULTILINE_QUOTE(...) #__VA_ARGS__
1313
const char * clblast_dequant = MULTILINE_QUOTE(
1414

15+
typedef uchar uint8_t;
16+
typedef int int32_t;
17+
typedef uint uint32_t;
18+
19+
constant uint QK4_0 = 32;
1520
struct block_q4_0
1621
{
1722
float d;
18-
uchar qs[16];
23+
uint8_t qs[QK4_0 / 2];
1924
};
2025

21-
__kernel void dequantize_row_q4_0(__global struct block_q4_0* blocks, __global float* result) {
22-
const uint i = get_global_id(0) / 32;
23-
const uint l = get_local_id(0);
24-
25-
const float d = blocks[i].d;
26-
27-
const uchar vi = blocks[i].qs[l];
28-
29-
const uint index = i*32 + l*2;
30-
result[index + 0] = ((vi & 0xf) - 8)*d;
31-
result[index + 1] = ((vi >> 4) - 8)*d;
32-
}
33-
26+
constant uint QK4_1 = 32;
3427
struct block_q4_1
3528
{
3629
float d;
3730
float m;
38-
uchar qs[16];
31+
uint8_t qs[QK4_1 / 2];
3932
};
4033

41-
__kernel void dequantize_row_q4_1(__global struct block_q4_1* blocks, __global float* result) {
42-
const uint i = get_global_id(0) / 32;
43-
const uint l = get_local_id(0);
44-
45-
const float d = blocks[i].d;
46-
const float m = blocks[i].m;
47-
48-
const uchar vi = blocks[i].qs[l];
34+
constant uint QK5_0 = 32;
35+
struct __attribute__ ((packed)) block_q5_0
36+
{
37+
half d;
38+
uint32_t qh;
39+
uint8_t qs[QK5_0 / 2];
40+
};
4941

50-
const uint index = i*32 + l*2;
51-
result[index + 0] = (vi & 0xf) * d + m;
52-
result[index + 1] = (vi >> 4) * d + m;
53-
}
42+
constant uint QK5_1 = 32;
43+
struct block_q5_1
44+
{
45+
half d;
46+
half m;
47+
uint32_t qh;
48+
uint8_t qs[QK5_1 / 2];
49+
};
5450

55-
struct block_q5_0
51+
constant uint QK8_0 = 32;
52+
struct block_q8_0
5653
{
5754
float d;
58-
uint qh;
59-
uchar qs[16];
55+
uint8_t qs[QK8_0];
6056
};
6157

62-
__kernel void dequantize_row_q5_0(__global struct block_q5_0* blocks, __global float* result) {
63-
const uint i = get_global_id(0) / 32;
64-
const uint l = get_local_id(0);
6558

66-
const float d = blocks[i].d;
59+
__kernel void dequantize_row_q4_0(__global struct block_q4_0* x, __global float* y) {
60+
constant uint qk = QK4_0;
6761

68-
const uchar vi = blocks[i].qs[l];
62+
const uint i = get_global_id(0) / qk;
63+
const uint j = get_local_id(0);
6964

70-
const uint l2 = l * 2;
65+
const float d = x[i].d;
7166

72-
const uchar vh0 = ((blocks[i].qh & (1 << (l2 + 0))) >> (l2 + 0)) << 4;
73-
const uchar vh1 = ((blocks[i].qh & (1 << (l2 + 1))) >> (l2 + 1)) << 4;
67+
const int x0 = (x[i].qs[j] & 0xf) - 8;
68+
const int x1 = (x[i].qs[j] >> 4) - 8;
7469

75-
const uint index = i*32 + l2;
76-
result[index + 0] = (((vi & 0xf) | vh0) - 16)*d;
77-
result[index + 1] = (((vi >> 4) | vh1) - 16)*d;
70+
y[i*qk + j + 0 ] = x0*d;
71+
y[i*qk + j + qk/2] = x1*d;
7872
}
7973

80-
struct block_q5_1
81-
{
82-
ushort d;
83-
ushort m;
84-
uint qh;
85-
uchar qs[16];
86-
};
74+
__kernel void dequantize_row_q4_1(__global struct block_q4_1* x, __global float* y) {
75+
constant uint qk = QK4_1;
8776

88-
__kernel void dequantize_row_q5_1(__global struct block_q5_1* blocks, __global float* result) {
89-
const uint i = get_global_id(0) / 32;
90-
const uint l = get_local_id(0);
77+
const uint i = get_global_id(0) / qk;
78+
const uint j = get_local_id(0);
9179

92-
const float d = vload_half(0, (__global half*) &blocks[i].d);
93-
const float m = vload_half(0, (__global half*) &blocks[i].m);
80+
const float d = x[i].d;
81+
const float m = x[i].m;
9482

95-
const uchar vi = blocks[i].qs[l];
83+
const int x0 = (x[i].qs[j] & 0xf);
84+
const int x1 = (x[i].qs[j] >> 4);
85+
86+
y[i*qk + j + 0 ] = x0*d + m;
87+
y[i*qk + j + qk/2] = x1*d + m;
88+
}
9689

97-
const uint l2 = l * 2;
90+
__kernel void dequantize_row_q5_0(__global struct block_q5_0* x, __global float* y) {
91+
constant uint qk = QK5_0;
9892

99-
const uchar vh0 = ((blocks[i].qh & (1 << (l2 + 0))) >> (l2 + 0)) << 4;
100-
const uchar vh1 = ((blocks[i].qh & (1 << (l2 + 1))) >> (l2 + 1)) << 4;
93+
const uint i = get_global_id(0) / qk;
94+
const uint j = get_local_id(0);
10195

102-
const uint index = i*32 + l2;
103-
result[index + 0] = ((vi & 0xf) | vh0)*d + m;
104-
result[index + 1] = ((vi >> 4) | vh1)*d + m;
96+
const float d = vload_half(0, (__global half*) &x[i].d);
97+
98+
uint32_t qh = x[i].qh;
99+
100+
const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
101+
const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10;
102+
103+
const int32_t x0 = ((x[i].qs[j] & 0xf) | xh_0) - 16;
104+
const int32_t x1 = ((x[i].qs[j] >> 4) | xh_1) - 16;
105+
106+
y[i*qk + j + 0 ] = x0*d;
107+
y[i*qk + j + qk/2] = x1*d;
105108
}
106109

107-
struct block_q8_0
108-
{
109-
float d;
110-
char qs[32];
111-
};
110+
__kernel void dequantize_row_q5_1(__global struct block_q5_1* x, __global float* y) {
111+
constant uint qk = QK5_1;
112+
113+
const uint i = get_global_id(0) / qk;
114+
const uint j = get_local_id(0);
115+
116+
const float d = vload_half(0, (__global half*) &x[i].d);
117+
const float m = vload_half(0, (__global half*) &x[i].m);
112118

113-
__kernel void dequantize_row_q8_0(__global struct block_q8_0* blocks, __global float* result) {
114-
const uint i = get_global_id(0) / 32;
115-
const uint l = get_local_id(0);
119+
uint32_t qh = x[i].qh;
116120

117-
result[i*32 + l] = blocks[i].qs[l] * blocks[i].d;
121+
const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
122+
const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10;
123+
124+
const int x0 = (x[i].qs[j] & 0xf) | xh_0;
125+
const int x1 = (x[i].qs[j] >> 4) | xh_1;
126+
127+
y[i*qk + j + 0 ] = x0*d + m;
128+
y[i*qk + j + qk/2] = x1*d + m;
129+
}
130+
131+
__kernel void dequantize_row_q8_0(__global struct block_q8_0* x, __global float* y) {
132+
constant uint qk = QK8_0;
133+
const uint i = get_global_id(0) / qk;
134+
const uint j = get_local_id(0);
135+
136+
const float d = x[i].d;
137+
y[i*qk + j] = x[i].qs[j]*d;
118138
}
119139

120140
);
@@ -128,20 +148,6 @@ __kernel void dequantize_row_q8_0(__global struct block_q8_0* blocks, __global f
128148
} \
129149
} while (0)
130150

131-
#define QK5_0 32
132-
typedef struct {
133-
ggml_fp16_t d; // delta
134-
uint8_t qh[4]; // 5-th bit of quants
135-
uint8_t qs[QK5_0 / 2]; // nibbles / quants
136-
} block_q5_0;
137-
138-
139-
typedef struct {
140-
float d; // delta
141-
uint32_t qh; // 5-th bit of quants
142-
uint8_t qs[QK5_0 / 2]; // nibbles / quants
143-
} cl_block_q5_0;
144-
145151
static cl_platform_id platform;
146152
static cl_device_id device;
147153
static cl_context context;
@@ -252,7 +258,6 @@ void ggml_cl_sgemm_wrapper(
252258
cl_kernel kernel;
253259
size_t global = n * k, local, size_qb;
254260
bool dequant;
255-
cl_block_q5_0* cl_host_b;
256261

257262
switch (btype) {
258263
case GGML_TYPE_F32:
@@ -274,18 +279,7 @@ void ggml_cl_sgemm_wrapper(
274279
dequant = true;
275280
kernel = kernel_q5_0;
276281
local = 16;
277-
// For some reason OpenCL seems to be incapable of working with structs of size 22.
278-
// 20 and 24 bytes are fine. Workaround to do the fp16 to fp32 step on CPU...
279-
// TODO Find the reason, fix and remove workaround.
280-
const block_q5_0* b = (const block_q5_0*) host_b;
281-
cl_host_b = (cl_block_q5_0*) malloc(sizeof(cl_block_q5_0) * global / 32);
282-
for (size_t i = 0; i < global / 32; i++) {
283-
cl_host_b[i].d = ggml_fp16_to_fp32(b[i].d);
284-
memcpy(&cl_host_b[i].qh, b[i].qh, sizeof(uint32_t));
285-
memcpy(&cl_host_b[i].qs, b[i].qs, QK5_0 / 2);
286-
}
287-
host_b = (const float*) cl_host_b;
288-
size_qb = global * (sizeof(float) + sizeof(uint32_t) + local) / 32;
282+
size_qb = global * (sizeof(ggml_fp16_t) + sizeof(uint32_t) + local) / 32;
289283
break;
290284
case GGML_TYPE_Q5_1:
291285
dequant = true;
@@ -364,7 +358,4 @@ void ggml_cl_sgemm_wrapper(
364358
clWaitForEvents(1, &ev_c);
365359
clReleaseEvent(ev_sgemm);
366360
clReleaseEvent(ev_c);
367-
if (btype == GGML_TYPE_Q5_0) {
368-
free((void*) cl_host_b);
369-
}
370361
}

0 commit comments

Comments
 (0)