3
3
#define CL_TARGET_OPENCL_VERSION 110
4
4
#include <clblast_c.h>
5
5
6
+ #include <stdlib.h>
6
7
#include <stdio.h>
7
8
#include <string.h>
8
9
9
10
#include "ggml.h"
10
11
11
- #include "ggml-opencl-dequant.cl"
12
+ #define MULTILINE_QUOTE (...) #__VA_ARGS__
13
+ const char * clblast_dequant = MULTILINE_QUOTE (
14
+
15
+ struct block_q4_0
16
+ {
17
+ float d ;
18
+ uchar qs [16 ];
19
+ };
20
+
21
+ __kernel void dequantize_row_q4_0 (__global struct block_q4_0 * blocks , __global float * result ) {
22
+ const uint i = get_global_id (0 ) / 32 ;
23
+ const uint l = get_local_id (0 );
24
+
25
+ const float d = blocks [i ].d ;
26
+
27
+ const uchar vi = blocks [i ].qs [l ];
28
+
29
+ const uint index = i * 32 + l * 2 ;
30
+ result [index + 0 ] = ((vi & 0xf ) - 8 )* d ;
31
+ result [index + 1 ] = ((vi >> 4 ) - 8 )* d ;
32
+ }
33
+
34
+ struct block_q4_1
35
+ {
36
+ float d ;
37
+ float m ;
38
+ uchar qs [16 ];
39
+ };
40
+
41
+ __kernel void dequantize_row_q4_1 (__global struct block_q4_1 * blocks , __global float * result ) {
42
+ const uint i = get_global_id (0 ) / 32 ;
43
+ const uint l = get_local_id (0 );
44
+
45
+ const float d = blocks [i ].d ;
46
+ const float m = blocks [i ].m ;
47
+
48
+ const uchar vi = blocks [i ].qs [l ];
49
+
50
+ const uint index = i * 32 + l * 2 ;
51
+ result [index + 0 ] = (vi & 0xf ) * d + m ;
52
+ result [index + 1 ] = (vi >> 4 ) * d + m ;
53
+ }
54
+
55
+ struct block_q4_2
56
+ {
57
+ ushort d ;
58
+ uchar qs [8 ];
59
+ };
60
+
61
+ __kernel void dequantize_row_q4_2 (__global struct block_q4_2 * blocks , __global float * result ) {
62
+ const uint i = get_global_id (0 ) / 16 ;
63
+ const uint l = get_local_id (0 );
64
+
65
+ const float d = vload_half (0 , (__global half * ) & blocks [i ].d );
66
+
67
+ const uchar vi = blocks [i ].qs [l ];
68
+
69
+ const uint index = i * 16 + l * 2 ;
70
+ result [index + 0 ] = ((vi & 0xf ) - 8 )* d ;
71
+ result [index + 1 ] = ((vi >> 4 ) - 8 )* d ;
72
+ }
73
+
74
+
75
+ struct block_q5_0
76
+ {
77
+ float d ;
78
+ uint qh ;
79
+ uchar qs [16 ];
80
+ };
81
+
82
+ __kernel void dequantize_row_q5_0 (__global struct block_q5_0 * blocks , __global float * result ) {
83
+ const uint i = get_global_id (0 ) / 32 ;
84
+ const uint l = get_local_id (0 );
85
+
86
+ const float d = blocks [i ].d ;
87
+
88
+ const uchar vi = blocks [i ].qs [l ];
89
+
90
+ const uint l2 = l * 2 ;
91
+
92
+ const uchar vh0 = ((blocks [i ].qh & (1 << (l2 + 0 ))) >> (l2 + 0 )) << 4 ;
93
+ const uchar vh1 = ((blocks [i ].qh & (1 << (l2 + 1 ))) >> (l2 + 1 )) << 4 ;
94
+
95
+ const uint index = i * 32 + l2 ;
96
+ result [index + 0 ] = (((vi & 0xf ) | vh0 ) - 16 )* d ;
97
+ result [index + 1 ] = (((vi >> 4 ) | vh1 ) - 16 )* d ;
98
+ }
99
+
100
+ struct block_q5_1
101
+ {
102
+ ushort d ;
103
+ ushort m ;
104
+ uint qh ;
105
+ uchar qs [16 ];
106
+ };
107
+
108
+ __kernel void dequantize_row_q5_1 (__global struct block_q5_1 * blocks , __global float * result ) {
109
+ const uint i = get_global_id (0 ) / 32 ;
110
+ const uint l = get_local_id (0 );
111
+
112
+ const float d = vload_half (0 , (__global half * ) & blocks [i ].d );
113
+ const float m = vload_half (0 , (__global half * ) & blocks [i ].m );
114
+
115
+ const uchar vi = blocks [i ].qs [l ];
116
+
117
+ const uint l2 = l * 2 ;
118
+
119
+ const uchar vh0 = ((blocks [i ].qh & (1 << (l2 + 0 ))) >> (l2 + 0 )) << 4 ;
120
+ const uchar vh1 = ((blocks [i ].qh & (1 << (l2 + 1 ))) >> (l2 + 1 )) << 4 ;
121
+
122
+ const uint index = i * 32 + l2 ;
123
+ result [index + 0 ] = ((vi & 0xf ) | vh0 )* d + m ;
124
+ result [index + 1 ] = ((vi >> 4 ) | vh1 )* d + m ;
125
+ }
126
+
127
+ struct block_q8_0
128
+ {
129
+ float d ;
130
+ char qs [32 ];
131
+ };
132
+
133
+ __kernel void dequantize_row_q8_0 (__global struct block_q8_0 * blocks , __global float * result ) {
134
+ const uint i = get_global_id (0 ) / 32 ;
135
+ const uint l = get_local_id (0 );
136
+
137
+ result [i * 32 + l ] = blocks [i ].qs [l ] * blocks [i ].d ;
138
+ }
139
+
140
+ );
12
141
13
142
#define CL_CHECK (err , name ) \
14
143
do { \
19
148
} \
20
149
} while (0)
21
150
151
+ #define QK5_0 32
152
+ typedef struct {
153
+ ggml_fp16_t d ; // delta
154
+ uint8_t qh [4 ]; // 5-th bit of quants
155
+ uint8_t qs [QK5_0 / 2 ]; // nibbles / quants
156
+ } block_q5_0 ;
157
+
158
+
159
+ typedef struct {
160
+ float d ; // delta
161
+ uint32_t qh ; // 5-th bit of quants
162
+ uint8_t qs [QK5_0 / 2 ]; // nibbles / quants
163
+ } cl_block_q5_0 ;
164
+
22
165
static cl_platform_id platform ;
23
166
static cl_device_id device ;
24
167
static cl_context context ;
25
168
static cl_command_queue queue ;
26
169
static cl_program program ;
27
- static cl_kernel kernel_q4_0 , kernel_q4_1 , kernel_q4_2 ;
170
+ static cl_kernel kernel_q4_0 , kernel_q4_1 , kernel_q4_2 , kernel_q5_0 , kernel_q5_1 , kernel_q8_0 ;
28
171
static cl_mem cl_buffer_a , cl_buffer_qb , cl_buffer_b , cl_buffer_c ;
29
172
static size_t cl_size_a = 0 , cl_size_qb = 0 , cl_size_b = 0 , cl_size_c = 0 ;
30
173
@@ -97,6 +240,12 @@ void ggml_cl_init(void) {
97
240
CL_CHECK (err , "clCreateKernel" );
98
241
kernel_q4_2 = clCreateKernel (program , "dequantize_row_q4_2" , & err );
99
242
CL_CHECK (err , "clCreateKernel" );
243
+ kernel_q5_0 = clCreateKernel (program , "dequantize_row_q5_0" , & err );
244
+ CL_CHECK (err , "clCreateKernel" );
245
+ kernel_q5_1 = clCreateKernel (program , "dequantize_row_q5_1" , & err );
246
+ CL_CHECK (err , "clCreateKernel" );
247
+ kernel_q8_0 = clCreateKernel (program , "dequantize_row_q8_0" , & err );
248
+ CL_CHECK (err , "clCreateKernel" );
100
249
}
101
250
102
251
static void ggml_cl_malloc (size_t req_size , size_t * cur_size , cl_mem_flags flags , cl_mem * buf ) {
@@ -125,6 +274,7 @@ void ggml_cl_sgemm_wrapper(
125
274
cl_kernel kernel ;
126
275
size_t global = n * k , local , size_qb ;
127
276
bool dequant ;
277
+ cl_block_q5_0 * cl_host_b ;
128
278
129
279
switch (btype ) {
130
280
case GGML_TYPE_F32 :
@@ -146,7 +296,36 @@ void ggml_cl_sgemm_wrapper(
146
296
dequant = true;
147
297
kernel = kernel_q4_2 ;
148
298
local = 8 ;
149
- size_qb = global * (sizeof (short ) + local ) / 16 ;
299
+ size_qb = global * (sizeof (ggml_fp16_t ) + local ) / 16 ;
300
+ break ;
301
+ case GGML_TYPE_Q5_0 :
302
+ dequant = true;
303
+ kernel = kernel_q5_0 ;
304
+ local = 16 ;
305
+ // For some reason OpenCL seems to be incapable of working with structs of size 22.
306
+ // 20 and 24 bytes are fine. Workaround to do the fp16 to fp32 step on CPU...
307
+ // TODO Find the reason, fix and remove workaround.
308
+ const block_q5_0 * b = (const block_q5_0 * ) host_b ;
309
+ cl_host_b = (cl_block_q5_0 * ) malloc (sizeof (cl_block_q5_0 ) * global / 32 );
310
+ for (size_t i = 0 ; i < global / 32 ; i ++ ) {
311
+ cl_host_b [i ].d = ggml_fp16_to_fp32 (b [i ].d );
312
+ memcpy (& cl_host_b [i ].qh , b [i ].qh , sizeof (uint32_t ));
313
+ memcpy (& cl_host_b [i ].qs , b [i ].qs , QK5_0 / 2 );
314
+ }
315
+ host_b = (const float * ) cl_host_b ;
316
+ size_qb = global * (sizeof (float ) + sizeof (uint32_t ) + local ) / 32 ;
317
+ break ;
318
+ case GGML_TYPE_Q5_1 :
319
+ dequant = true;
320
+ kernel = kernel_q5_1 ;
321
+ local = 16 ;
322
+ size_qb = global * (sizeof (ggml_fp16_t ) * 2 + sizeof (uint32_t ) + local ) / 32 ;
323
+ break ;
324
+ case GGML_TYPE_Q8_0 :
325
+ dequant = true;
326
+ kernel = kernel_q8_0 ;
327
+ local = 32 ;
328
+ size_qb = global * (sizeof (float ) + local ) / 32 ;
150
329
break ;
151
330
default :
152
331
fprintf (stderr , "Error: Unsupported OpenCL btype %d\n" , btype );
@@ -171,12 +350,15 @@ void ggml_cl_sgemm_wrapper(
171
350
err = clSetKernelArg (kernel , 0 , sizeof (cl_mem ), & cl_buffer_qb );
172
351
err |= clSetKernelArg (kernel , 1 , sizeof (cl_mem ), & cl_buffer_b );
173
352
CL_CHECK (err , "clSetKernelArg" );
174
- clEnqueueWriteBuffer (queue , cl_buffer_qb , CL_FALSE , 0 , size_qb , host_b , 0 , NULL , & ev_qb );
353
+ err = clEnqueueWriteBuffer (queue , cl_buffer_qb , CL_FALSE , 0 , size_qb , host_b , 0 , NULL , & ev_qb );
354
+ CL_CHECK (err , "clEnqueueWriteBuffer qb" );
175
355
} else {
176
- clEnqueueWriteBuffer (queue , cl_buffer_b , CL_FALSE , 0 , size_b , host_b , 0 , NULL , & ev_b );
356
+ err = clEnqueueWriteBuffer (queue , cl_buffer_b , CL_FALSE , 0 , size_b , host_b , 0 , NULL , & ev_b );
357
+ CL_CHECK (err , "clEnqueueWriteBuffer b" );
177
358
}
178
359
179
- clEnqueueWriteBuffer (queue , cl_buffer_a , CL_FALSE , 0 , size_a , host_a , 0 , NULL , & ev_a );
360
+ err = clEnqueueWriteBuffer (queue , cl_buffer_a , CL_FALSE , 0 , size_a , host_a , 0 , NULL , & ev_a );
361
+ CL_CHECK (err , "clEnqueueWriteBuffer a" );
180
362
if (dequant ) {
181
363
err = clEnqueueNDRangeKernel (queue , kernel , 1 , NULL , & global , & local , 1 , & ev_qb , & ev_b );
182
364
CL_CHECK (err , "clEnqueueNDRangeKernel" );
@@ -188,15 +370,20 @@ void ggml_cl_sgemm_wrapper(
188
370
clReleaseEvent (ev_b );
189
371
190
372
cl_event ev_sgemm ;
191
- CLBlastSgemm ((CLBlastLayout )order ,
192
- (CLBlastTranspose )trans_a , (CLBlastTranspose )trans_b ,
193
- m , n , k ,
194
- alpha ,
195
- cl_buffer_a , 0 , lda ,
196
- cl_buffer_b , 0 , ldb ,
197
- beta ,
198
- cl_buffer_c , 0 , ldc ,
199
- & queue , & ev_sgemm );
373
+ CLBlastStatusCode status = CLBlastSgemm ((CLBlastLayout )order ,
374
+ (CLBlastTranspose )trans_a , (CLBlastTranspose )trans_b ,
375
+ m , n , k ,
376
+ alpha ,
377
+ cl_buffer_a , 0 , lda ,
378
+ cl_buffer_b , 0 , ldb ,
379
+ beta ,
380
+ cl_buffer_c , 0 , ldc ,
381
+ & queue , & ev_sgemm );
382
+
383
+ if (status != CLBlastSuccess ) {
384
+ fprintf (stderr , "Error: CLBlast SGEMM %d\n" , status );
385
+ abort ();
386
+ }
200
387
201
388
cl_event ev_c ;
202
389
clEnqueueReadBuffer (queue , cl_buffer_c , CL_TRUE , 0 , size_c , host_c , 1 , & ev_sgemm , & ev_c );
@@ -205,4 +392,7 @@ void ggml_cl_sgemm_wrapper(
205
392
clWaitForEvents (1 , & ev_c );
206
393
clReleaseEvent (ev_sgemm );
207
394
clReleaseEvent (ev_c );
395
+ if (btype == GGML_TYPE_Q5_0 ) {
396
+ free ((void * ) cl_host_b );
397
+ }
208
398
}
0 commit comments