12
12
#define MULTILINE_QUOTE (...) #__VA_ARGS__
13
13
const char * clblast_dequant = MULTILINE_QUOTE (
14
14
15
+ typedef uchar uint8_t ;
16
+ typedef int int32_t ;
17
+ typedef uint uint32_t ;
18
+
19
+ constant uint QK4_0 = 32 ;
15
20
struct block_q4_0
16
21
{
17
22
float d ;
18
- uchar qs [16 ];
23
+ uint8_t qs [QK4_0 / 2 ];
19
24
};
20
25
21
- __kernel void dequantize_row_q4_0 (__global struct block_q4_0 * blocks , __global float * result ) {
22
- const uint i = get_global_id (0 ) / 32 ;
23
- const uint l = get_local_id (0 );
24
-
25
- const float d = blocks [i ].d ;
26
-
27
- const uchar vi = blocks [i ].qs [l ];
28
-
29
- const uint index = i * 32 + l * 2 ;
30
- result [index + 0 ] = ((vi & 0xf ) - 8 )* d ;
31
- result [index + 1 ] = ((vi >> 4 ) - 8 )* d ;
32
- }
33
-
26
+ constant uint QK4_1 = 32 ;
34
27
struct block_q4_1
35
28
{
36
29
float d ;
37
30
float m ;
38
- uchar qs [16 ];
31
+ uint8_t qs [QK4_1 / 2 ];
39
32
};
40
33
41
- __kernel void dequantize_row_q4_1 (__global struct block_q4_1 * blocks , __global float * result ) {
42
- const uint i = get_global_id (0 ) / 32 ;
43
- const uint l = get_local_id (0 );
44
-
45
- const float d = blocks [i ].d ;
46
- const float m = blocks [i ].m ;
47
-
48
- const uchar vi = blocks [i ].qs [l ];
34
+ constant uint QK5_0 = 32 ;
35
+ struct __attribute__ ((packed )) block_q5_0
36
+ {
37
+ half d ;
38
+ uint32_t qh ;
39
+ uint8_t qs [QK5_0 / 2 ];
40
+ };
49
41
50
- const uint index = i * 32 + l * 2 ;
51
- result [index + 0 ] = (vi & 0xf ) * d + m ;
52
- result [index + 1 ] = (vi >> 4 ) * d + m ;
53
- }
42
+ constant uint QK5_1 = 32 ;
43
+ struct block_q5_1
44
+ {
45
+ half d ;
46
+ half m ;
47
+ uint32_t qh ;
48
+ uint8_t qs [QK5_1 / 2 ];
49
+ };
54
50
55
- struct block_q5_0
51
+ constant uint QK8_0 = 32 ;
52
+ struct block_q8_0
56
53
{
57
54
float d ;
58
- uint qh ;
59
- uchar qs [16 ];
55
+ uint8_t qs [QK8_0 ];
60
56
};
61
57
62
- __kernel void dequantize_row_q5_0 (__global struct block_q5_0 * blocks , __global float * result ) {
63
- const uint i = get_global_id (0 ) / 32 ;
64
- const uint l = get_local_id (0 );
65
58
66
- const float d = blocks [i ].d ;
59
+ __kernel void dequantize_row_q4_0 (__global struct block_q4_0 * x , __global float * y ) {
60
+ constant uint qk = QK4_0 ;
67
61
68
- const uchar vi = blocks [i ].qs [l ];
62
+ const uint i = get_global_id (0 ) / qk ;
63
+ const uint j = get_local_id (0 );
69
64
70
- const uint l2 = l * 2 ;
65
+ const float d = x [ i ]. d ;
71
66
72
- const uchar vh0 = (( blocks [i ].qh & ( 1 << ( l2 + 0 ))) >> ( l2 + 0 )) << 4 ;
73
- const uchar vh1 = (( blocks [i ].qh & ( 1 << ( l2 + 1 ))) >> ( l2 + 1 )) << 4 ;
67
+ const int x0 = (x [i ].qs [ j ] & 0xf ) - 8 ;
68
+ const int x1 = (x [i ].qs [ j ] >> 4 ) - 8 ;
74
69
75
- const uint index = i * 32 + l2 ;
76
- result [index + 0 ] = (((vi & 0xf ) | vh0 ) - 16 )* d ;
77
- result [index + 1 ] = (((vi >> 4 ) | vh1 ) - 16 )* d ;
70
+ y [i * qk + j + 0 ] = x0 * d ;
71
+ y [i * qk + j + qk /2 ] = x1 * d ;
78
72
}
79
73
80
- struct block_q5_1
81
- {
82
- ushort d ;
83
- ushort m ;
84
- uint qh ;
85
- uchar qs [16 ];
86
- };
74
+ __kernel void dequantize_row_q4_1 (__global struct block_q4_1 * x , __global float * y ) {
75
+ constant uint qk = QK4_1 ;
87
76
88
- __kernel void dequantize_row_q5_1 (__global struct block_q5_1 * blocks , __global float * result ) {
89
- const uint i = get_global_id (0 ) / 32 ;
90
- const uint l = get_local_id (0 );
77
+ const uint i = get_global_id (0 ) / qk ;
78
+ const uint j = get_local_id (0 );
91
79
92
- const float d = vload_half ( 0 , ( __global half * ) & blocks [i ].d ) ;
93
- const float m = vload_half ( 0 , ( __global half * ) & blocks [i ].m ) ;
80
+ const float d = x [i ].d ;
81
+ const float m = x [i ].m ;
94
82
95
- const uchar vi = blocks [i ].qs [l ];
83
+ const int x0 = (x [i ].qs [j ] & 0xf );
84
+ const int x1 = (x [i ].qs [j ] >> 4 );
85
+
86
+ y [i * qk + j + 0 ] = x0 * d + m ;
87
+ y [i * qk + j + qk /2 ] = x1 * d + m ;
88
+ }
96
89
97
- const uint l2 = l * 2 ;
90
+ __kernel void dequantize_row_q5_0 (__global struct block_q5_0 * x , __global float * y ) {
91
+ constant uint qk = QK5_0 ;
98
92
99
- const uchar vh0 = (( blocks [ i ]. qh & ( 1 << ( l2 + 0 ))) >> ( l2 + 0 )) << 4 ;
100
- const uchar vh1 = (( blocks [ i ]. qh & ( 1 << ( l2 + 1 ))) >> ( l2 + 1 )) << 4 ;
93
+ const uint i = get_global_id ( 0 ) / qk ;
94
+ const uint j = get_local_id ( 0 ) ;
101
95
102
- const uint index = i * 32 + l2 ;
103
- result [index + 0 ] = ((vi & 0xf ) | vh0 )* d + m ;
104
- result [index + 1 ] = ((vi >> 4 ) | vh1 )* d + m ;
96
+ const float d = vload_half (0 , (__global half * ) & x [i ].d );
97
+
98
+ uint32_t qh = x [i ].qh ;
99
+
100
+ const uint8_t xh_0 = ((qh >> (j + 0 )) << 4 ) & 0x10 ;
101
+ const uint8_t xh_1 = ((qh >> (j + 12 )) ) & 0x10 ;
102
+
103
+ const int32_t x0 = ((x [i ].qs [j ] & 0xf ) | xh_0 ) - 16 ;
104
+ const int32_t x1 = ((x [i ].qs [j ] >> 4 ) | xh_1 ) - 16 ;
105
+
106
+ y [i * qk + j + 0 ] = x0 * d ;
107
+ y [i * qk + j + qk /2 ] = x1 * d ;
105
108
}
106
109
107
- struct block_q8_0
108
- {
109
- float d ;
110
- char qs [32 ];
111
- };
110
+ __kernel void dequantize_row_q5_1 (__global struct block_q5_1 * x , __global float * y ) {
111
+ constant uint qk = QK5_1 ;
112
+
113
+ const uint i = get_global_id (0 ) / qk ;
114
+ const uint j = get_local_id (0 );
115
+
116
+ const float d = vload_half (0 , (__global half * ) & x [i ].d );
117
+ const float m = vload_half (0 , (__global half * ) & x [i ].m );
112
118
113
- __kernel void dequantize_row_q8_0 (__global struct block_q8_0 * blocks , __global float * result ) {
114
- const uint i = get_global_id (0 ) / 32 ;
115
- const uint l = get_local_id (0 );
119
+ uint32_t qh = x [i ].qh ;
116
120
117
- result [i * 32 + l ] = blocks [i ].qs [l ] * blocks [i ].d ;
121
+ const uint8_t xh_0 = ((qh >> (j + 0 )) << 4 ) & 0x10 ;
122
+ const uint8_t xh_1 = ((qh >> (j + 12 )) ) & 0x10 ;
123
+
124
+ const int x0 = (x [i ].qs [j ] & 0xf ) | xh_0 ;
125
+ const int x1 = (x [i ].qs [j ] >> 4 ) | xh_1 ;
126
+
127
+ y [i * qk + j + 0 ] = x0 * d + m ;
128
+ y [i * qk + j + qk /2 ] = x1 * d + m ;
129
+ }
130
+
131
+ __kernel void dequantize_row_q8_0 (__global struct block_q8_0 * x , __global float * y ) {
132
+ constant uint qk = QK8_0 ;
133
+ const uint i = get_global_id (0 ) / qk ;
134
+ const uint j = get_local_id (0 );
135
+
136
+ const float d = x [i ].d ;
137
+ y [i * qk + j ] = x [i ].qs [j ]* d ;
118
138
}
119
139
120
140
);
@@ -128,20 +148,6 @@ __kernel void dequantize_row_q8_0(__global struct block_q8_0* blocks, __global f
128
148
} \
129
149
} while (0)
130
150
131
- #define QK5_0 32
132
- typedef struct {
133
- ggml_fp16_t d ; // delta
134
- uint8_t qh [4 ]; // 5-th bit of quants
135
- uint8_t qs [QK5_0 / 2 ]; // nibbles / quants
136
- } block_q5_0 ;
137
-
138
-
139
- typedef struct {
140
- float d ; // delta
141
- uint32_t qh ; // 5-th bit of quants
142
- uint8_t qs [QK5_0 / 2 ]; // nibbles / quants
143
- } cl_block_q5_0 ;
144
-
145
151
static cl_platform_id platform ;
146
152
static cl_device_id device ;
147
153
static cl_context context ;
@@ -252,7 +258,6 @@ void ggml_cl_sgemm_wrapper(
252
258
cl_kernel kernel ;
253
259
size_t global = n * k , local , size_qb ;
254
260
bool dequant ;
255
- cl_block_q5_0 * cl_host_b ;
256
261
257
262
switch (btype ) {
258
263
case GGML_TYPE_F32 :
@@ -274,18 +279,7 @@ void ggml_cl_sgemm_wrapper(
274
279
dequant = true;
275
280
kernel = kernel_q5_0 ;
276
281
local = 16 ;
277
- // For some reason OpenCL seems to be incapable of working with structs of size 22.
278
- // 20 and 24 bytes are fine. Workaround to do the fp16 to fp32 step on CPU...
279
- // TODO Find the reason, fix and remove workaround.
280
- const block_q5_0 * b = (const block_q5_0 * ) host_b ;
281
- cl_host_b = (cl_block_q5_0 * ) malloc (sizeof (cl_block_q5_0 ) * global / 32 );
282
- for (size_t i = 0 ; i < global / 32 ; i ++ ) {
283
- cl_host_b [i ].d = ggml_fp16_to_fp32 (b [i ].d );
284
- memcpy (& cl_host_b [i ].qh , b [i ].qh , sizeof (uint32_t ));
285
- memcpy (& cl_host_b [i ].qs , b [i ].qs , QK5_0 / 2 );
286
- }
287
- host_b = (const float * ) cl_host_b ;
288
- size_qb = global * (sizeof (float ) + sizeof (uint32_t ) + local ) / 32 ;
282
+ size_qb = global * (sizeof (ggml_fp16_t ) + sizeof (uint32_t ) + local ) / 32 ;
289
283
break ;
290
284
case GGML_TYPE_Q5_1 :
291
285
dequant = true;
@@ -364,7 +358,4 @@ void ggml_cl_sgemm_wrapper(
364
358
clWaitForEvents (1 , & ev_c );
365
359
clReleaseEvent (ev_sgemm );
366
360
clReleaseEvent (ev_c );
367
- if (btype == GGML_TYPE_Q5_0 ) {
368
- free ((void * ) cl_host_b );
369
- }
370
361
}
0 commit comments