1
1
#include "ggml-opencl.h"
2
2
3
- #include < atomic>
4
- #include < cstdio>
5
- #include < cstring>
3
+ #define CL_TARGET_OPENCL_VERSION 110
4
+ #include <clblast_c.h>
5
+
6
+ #include <stdio.h>
7
+ #include <string.h>
6
8
7
9
#include "ggml.h"
8
10
9
11
#include <ggml_clblast_dequant.cl>
10
12
13
/*
 * CL_CHECK(err, name) - abort the process when an OpenCL call reports failure.
 *
 *   err  - expression yielding a cl_int status; it is evaluated exactly once
 *          and stored in a local before being tested.
 *   name - C string naming the OpenCL call; printed in the diagnostic.
 *
 * When the status differs from CL_SUCCESS, the macro prints the call name,
 * the numeric status and the source location to stderr, then calls exit(1).
 *
 * The parameter list must follow the macro name with no whitespace in
 * between; otherwise the preprocessor defines an object-like macro and every
 * call site expands to invalid code.
 */
#define CL_CHECK(err, name)                                                                       \
    do {                                                                                          \
        cl_int err_ = (err);                                                                      \
        if (err_ != CL_SUCCESS) {                                                                 \
            fprintf(stderr, "OpenCL %s error %d at %s:%d\n", (name), err_, __FILE__, __LINE__);   \
            exit(1);                                                                              \
        }                                                                                         \
    } while (0)
21
+
11
22
cl_platform_id platform ;
12
23
cl_device_id device ;
13
24
cl_context context ;
@@ -74,7 +85,7 @@ void ggml_cl_init(void) {
74
85
printf ("Using Platform: %s Device: %s\n" , platform_buffer , device_buffer );
75
86
context = clCreateContext (NULL , 1 , & device , NULL , NULL , & err );
76
87
CL_CHECK (err , "clCreateContext" );
77
- queue = clCreateCommandQueue (context, device, 0 , &err);
88
+ queue = clCreateCommandQueue (context , device , CL_QUEUE_OUT_OF_ORDER_EXEC_MODE_ENABLE , & err );
78
89
CL_CHECK (err , "clCreateCommandQueue" );
79
90
80
91
free (platforms );
@@ -93,7 +104,7 @@ void ggml_cl_init(void) {
93
104
CL_CHECK (err , "clCreateKernel" );
94
105
}
95
106
96
- void ggml_cl_malloc (size_t req_size, size_t * cur_size, cl_mem_flags flags, cl_mem* buf) {
107
+ static void ggml_cl_malloc (size_t req_size , size_t * cur_size , cl_mem_flags flags , cl_mem * buf ) {
97
108
if (req_size <= * cur_size ) {
98
109
return ;
99
110
}
@@ -108,11 +119,14 @@ void ggml_cl_malloc(size_t req_size, size_t* cur_size, cl_mem_flags flags, cl_me
108
119
CL_CHECK (err , "clCreateBuffer" );
109
120
}
110
121
111
- void ggml_cl_sgemm_wrapper (const CLBlastLayout order, const CLBlastTranspose trans_a, const CLBlastTranspose trans_b, const int m, const int n, const int k, const float alpha, const void *host_a, const int lda, const float *host_b, const int ldb, const float beta, float *host_c, const int ldc, const int btype) {
122
+ void ggml_cl_sgemm_wrapper (
123
+ const enum ggml_blas_order order , const enum ggml_blas_op trans_a , const enum ggml_blas_op trans_b ,
124
+ const int m , const int n , const int k ,
125
+ const float alpha , const void * host_a , const int lda ,
126
+ const float * host_b , const int ldb , const float beta ,
127
+ float * host_c , const int ldc , const int btype ) {
112
128
cl_int err = 0 ;
113
129
114
- cl_event events[4 ] = { NULL };
115
-
116
130
cl_kernel kernel ;
117
131
size_t global = n * k , local , size_qb ;
118
132
bool dequant ;
@@ -162,42 +176,46 @@ void ggml_cl_sgemm_wrapper(const CLBlastLayout order, const CLBlastTranspose tra
162
176
ggml_cl_malloc (size_b , & cl_size_b , CL_MEM_READ_WRITE , & cl_buffer_b );
163
177
ggml_cl_malloc (size_c , & cl_size_c , CL_MEM_WRITE_ONLY , & cl_buffer_c );
164
178
179
+ cl_event ev_a , ev_qb , ev_b ;
180
+
165
181
if (dequant ) {
166
182
err = clSetKernelArg (kernel , 0 , sizeof (cl_mem ), & cl_buffer_qb );
167
183
err |= clSetKernelArg (kernel , 1 , sizeof (cl_mem ), & cl_buffer_b );
168
184
CL_CHECK (err , "clSetKernelArg" );
169
- clEnqueueWriteBuffer (queue, cl_buffer_qb, CL_FALSE, 0 , size_qb, host_b, 0 , NULL , events + 1 );
185
+ clEnqueueWriteBuffer (queue , cl_buffer_qb , CL_FALSE , 0 , size_qb , host_b , 0 , NULL , & ev_qb );
170
186
} else {
171
- clEnqueueWriteBuffer (queue, cl_buffer_b, CL_FALSE, 0 , size_b, host_b, 0 , NULL , events + 1 );
187
+ clEnqueueWriteBuffer (queue , cl_buffer_b , CL_FALSE , 0 , size_b , host_b , 0 , NULL , & ev_b );
172
188
}
173
189
174
- clEnqueueWriteBuffer (queue, cl_buffer_a, CL_FALSE, 0 , size_a, host_a, 0 , NULL , events );
190
+ clEnqueueWriteBuffer (queue , cl_buffer_a , CL_FALSE , 0 , size_a , host_a , 0 , NULL , & ev_a );
175
191
if (dequant ) {
176
- err = clEnqueueNDRangeKernel (queue, kernel, 1 , NULL , &global, &local, 1 , events + 1 , events + 3 );
192
+ err = clEnqueueNDRangeKernel (queue , kernel , 1 , NULL , & global , & local , 1 , & ev_qb , & ev_b );
177
193
CL_CHECK (err , "clEnqueueNDRangeKernel" );
178
194
}
179
- clWaitForEvents (dequant ? 4 : 3 , events );
180
- clReleaseEvent (events[ 0 ] );
181
- clReleaseEvent (events[ 1 ] );
182
- clReleaseEvent (events[ 2 ] );
195
+ clWaitForEvents (1 , & ev_a );
196
+ clWaitForEvents ( 1 , & ev_b );
197
+ clReleaseEvent (ev_a );
198
+ clReleaseEvent (ev_b );
183
199
if (dequant ) {
184
- clReleaseEvent (events[ 3 ] );
200
+ clReleaseEvent (ev_qb );
185
201
}
186
202
187
- CLBlastSgemm (order,
188
- trans_a, trans_b,
203
+ cl_event ev_sgemm ;
204
+ CLBlastSgemm ((CLBlastLayout )order ,
205
+ (CLBlastTranspose )trans_a , (CLBlastTranspose )trans_b ,
189
206
m , n , k ,
190
207
alpha ,
191
208
cl_buffer_a , 0 , lda ,
192
209
cl_buffer_b , 0 , ldb ,
193
210
beta ,
194
211
cl_buffer_c , 0 , ldc ,
195
- &queue, events );
212
+ & queue , & ev_sgemm );
196
213
197
- clEnqueueReadBuffer (queue, cl_buffer_c, CL_TRUE, 0 , size_c, host_c, 1 , events, events + 1 );
214
+ cl_event ev_c ;
215
+ clEnqueueReadBuffer (queue , cl_buffer_c , CL_TRUE , 0 , size_c , host_c , 1 , & ev_sgemm , & ev_c );
198
216
199
217
// Wait for completion
200
- clWaitForEvents (2 , events );
201
- clReleaseEvent (events[ 0 ] );
202
- clReleaseEvent (events[ 1 ] );
218
+ clWaitForEvents (1 , & ev_c );
219
+ clReleaseEvent (ev_sgemm );
220
+ clReleaseEvent (ev_c );
203
221
}
0 commit comments