Skip to content

Commit 7296c96

Browse files
0cc4m, LostRuins, slaren, SlyEcho, ggerganov
authored
ggml : add CLBlast support (#1164)
* Allow use of OpenCL GPU-based BLAS using CLBlast instead of OpenBLAS for context processing
* Improve CLBlast implementation, avoid recreating buffers, remove redundant transfers
* Finish merge of CLBlast support
* Move CLBlast implementation to separate file
  Add buffer reuse code (adapted from slaren's CUDA implementation)
* Add q4_2 and q4_3 CLBlast support, improve code
* Double CLBlast speed by disabling OpenBLAS thread workaround
  Co-authored-by: Concedo <[email protected]>
  Co-authored-by: slaren <[email protected]>
* Fix device selection env variable names
* Fix cast in OpenCL kernels
* Add CLBlast to CMakeLists.txt
* Replace buffer pool with static buffers a, b, qb, c
  Fix compile warnings
* Fix typos, use GGML_TYPE defines, improve code
* Improve btype dequant kernel selection code, add error if type is unsupported
* Improve code quality
* Move internal stuff out of header
* Use internal enums instead of CLBlast enums
* Remove leftover C++ includes and defines
* Make event use easier to read
  Co-authored-by: Henri Vasserman <[email protected]>
* Use C compiler for OpenCL files
* Simplify code, fix include
* First check error, then release event
* Make globals static, fix indentation
* Rename dequant kernels file to conform with other file names
* Fix import cl file name
---------
Co-authored-by: Concedo <[email protected]>
Co-authored-by: slaren <[email protected]>
Co-authored-by: Henri Vasserman <[email protected]>
Co-authored-by: Georgi Gerganov <[email protected]>
1 parent 78ec543 commit 7296c96

File tree

8 files changed

+411
-16
lines changed

8 files changed

+411
-16
lines changed

CMakeLists.txt

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ endif()
6767
option(LLAMA_ACCELERATE "llama: enable Accelerate framework" ON)
6868
option(LLAMA_OPENBLAS "llama: use OpenBLAS" OFF)
6969
option(LLAMA_CUBLAS "llama: use cuBLAS" OFF)
70+
option(LLAMA_CLBLAST "llama: use CLBlast" OFF)
7071

7172
option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE})
7273
option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE})
@@ -168,6 +169,21 @@ if (LLAMA_CUBLAS)
168169
endif()
169170
endif()
170171

172+
# Optional CLBlast (OpenCL) backend, enabled with -DLLAMA_CLBLAST=ON.
# When the package is found, this records the extra ggml sources, defines
# GGML_USE_CLBLAST for everything configured after this point, and appends the
# clblast library to LLAMA_EXTRA_LIBS. When it is not found, only a warning is
# printed and the build proceeds without the backend.
if (LLAMA_CLBLAST)
    find_package(CLBlast)
    if (NOT CLBlast_FOUND)
        message(WARNING "CLBlast not found")
    else()
        message(STATUS "CLBlast found")

        set(GGML_OPENCL_SOURCES ggml-opencl.c ggml-opencl.h)

        add_compile_definitions(GGML_USE_CLBLAST)

        list(APPEND LLAMA_EXTRA_LIBS clblast)
    endif()
endif()
186+
171187
if (LLAMA_ALL_WARNINGS)
172188
if (NOT MSVC)
173189
set(c_flags
@@ -307,7 +323,8 @@ endif()
307323
add_library(ggml OBJECT
308324
ggml.c
309325
ggml.h
310-
${GGML_CUDA_SOURCES})
326+
${GGML_CUDA_SOURCES}
327+
${GGML_OPENCL_SOURCES})
311328

312329
target_include_directories(ggml PUBLIC .)
313330
target_compile_features(ggml PUBLIC c_std_11) # don't bump

Makefile

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,14 +105,21 @@ ifdef LLAMA_OPENBLAS
105105
LDFLAGS += -lopenblas
106106
endif
107107
ifdef LLAMA_CUBLAS
108-
CFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include
109-
LDFLAGS += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib
108+
CFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include
109+
LDFLAGS += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib
110110
OBJS += ggml-cuda.o
111111
NVCC = nvcc
112112
NVCCFLAGS = --forward-unknown-to-host-compiler -arch=native
113113
ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
114114
$(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@
115115
endif
116+
# Optional CLBlast (OpenCL) backend: active when LLAMA_CLBLAST is set to a
# non-empty value. Defines GGML_USE_CLBLAST for the C sources, links against
# CLBlast and OpenCL, and adds the ggml-opencl object to the build.
ifdef LLAMA_CLBLAST
CFLAGS  += -DGGML_USE_CLBLAST
LDFLAGS += -lclblast -lOpenCL
OBJS    += ggml-opencl.o
# The OpenCL host code is plain C, so it is built with the C compiler.
ggml-opencl.o: ggml-opencl.c ggml-opencl.h
	$(CC) $(CFLAGS) -c $< -o $@
endif
116123
ifdef LLAMA_GPROF
117124
CFLAGS += -pg
118125
CXXFLAGS += -pg

ggml-opencl-dequant.cl

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
#define MULTILINE_QUOTE(...) #__VA_ARGS__
2+
const char * clblast_dequant = MULTILINE_QUOTE(
3+
4+
// Quantized block: one float scale followed by 16 bytes, each byte packing
// two 4-bit values (32 values per block).
struct block_q4_0
{
    float d;
    uchar qs[16];
};

// Expands q4_0 blocks into floats: every work-item unpacks one byte of a block
// into two consecutive outputs, computed as (nibble - 8) * scale.
// The block index comes from the global id, the byte index from the local id.
// NOTE(review): the host side appears to launch one work-item per output value
// with a work-group size of 16, so each block looks like it is written twice
// with identical results -- confirm before changing the launch size.
__kernel void dequantize_row_q4_0(__global struct block_q4_0* blocks, __global float* result) {
    const uint block_idx = get_global_id(0) / 32;
    const uint lane      = get_local_id(0);

    const float scale  = blocks[block_idx].d;
    const uchar packed = blocks[block_idx].qs[lane];

    // Low nibble goes to the even slot, high nibble to the odd slot.
    const uint out = block_idx*32 + lane*2;
    result[out + 0] = ((packed & 0xf) - 8)*scale;
    result[out + 1] = ((packed >> 4) - 8)*scale;
}
22+
23+
// Quantized block: float scale `d`, float offset `m`, then 16 bytes each
// packing two 4-bit values (32 values per block).
struct block_q4_1
{
    float d;
    float m;
    uchar qs[16];
};

// Expands q4_1 blocks into floats: every work-item unpacks one byte of a block
// into two consecutive outputs, computed as nibble * scale + offset.
// The block index comes from the global id, the byte index from the local id.
__kernel void dequantize_row_q4_1(__global struct block_q4_1* blocks, __global float* result) {
    const uint block_idx = get_global_id(0) / 32;
    const uint lane      = get_local_id(0);

    const float scale  = blocks[block_idx].d;
    const float offset = blocks[block_idx].m;
    const uchar packed = blocks[block_idx].qs[lane];

    // Low nibble goes to the even slot, high nibble to the odd slot.
    const uint out = block_idx*32 + lane*2;
    result[out + 0] = (packed & 0xf) * scale + offset;
    result[out + 1] = (packed >> 4) * scale + offset;
}
43+
44+
// Quantized block: a half-precision scale stored as raw 16 bits, followed by
// 8 bytes each packing two 4-bit values (16 values per block).
struct block_q4_2
{
    ushort d;
    uchar qs[8];
};

// Expands q4_2 blocks into floats: every work-item unpacks one byte of a block
// into two consecutive outputs, computed as (nibble - 8) * scale.
// The block index comes from the global id, the byte index from the local id.
__kernel void dequantize_row_q4_2(__global struct block_q4_2* blocks, __global float* result) {
    const uint i = get_global_id(0) / 16;
    const uint l = get_local_id(0);

    // The scale is stored as the bit pattern of a half; vload_half converts it to float.
    const float d = vload_half(0, (__global half*) &(blocks[i].d));

    const uchar vi = blocks[i].qs[l];

    // Low nibble goes to the even slot, high nibble to the odd slot.
    const uint index = i*16 + l*2;
    result[index + 0] = ((vi & 0xf) - 8)*d;
    result[index + 1] = ((vi >> 4) - 8)*d;
}
62+
63+
// Quantized block: half-precision scale `d` and offset `m`, both stored as raw
// 16 bits, followed by 8 bytes each packing two 4-bit values (16 per block).
struct block_q4_3
{
    ushort d;
    ushort m;
    uchar qs[8];
};

// Expands q4_3 blocks into floats: every work-item unpacks one byte of a block
// into two consecutive outputs, computed as nibble * scale + offset.
// The block index comes from the global id, the byte index from the local id.
__kernel void dequantize_row_q4_3(__global struct block_q4_3* blocks, __global float* result) {
    const uint block_idx = get_global_id(0) / 16;
    const uint lane      = get_local_id(0);

    // Both parameters are stored as half bit patterns; vload_half converts them to float.
    const float scale  = vload_half(0, (__global half*) &(blocks[block_idx].d));
    const float offset = vload_half(0, (__global half*) &(blocks[block_idx].m));
    const uchar packed = blocks[block_idx].qs[lane];

    // Low nibble goes to the even slot, high nibble to the odd slot.
    const uint out = block_idx*16 + lane*2;
    result[out + 0] = (packed & 0xf) * scale + offset;
    result[out + 1] = (packed >> 4) * scale + offset;
}
83+
84+
);

ggml-opencl.c

Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
1+
#include "ggml-opencl.h"

#define CL_TARGET_OPENCL_VERSION 110
#include <clblast_c.h>

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include "ggml.h"

#include "ggml-opencl-dequant.cl"
12+
13+
#define CL_CHECK(err, name) \
14+
do { \
15+
cl_int err_ = (err); \
16+
if (err_ != CL_SUCCESS) { \
17+
fprintf(stderr, "OpenCL %s error %d at %s:%d\n", name, err_, __FILE__, __LINE__); \
18+
exit(1); \
19+
} \
20+
} while (0)
21+
22+
static cl_platform_id platform;
23+
static cl_device_id device;
24+
static cl_context context;
25+
static cl_command_queue queue;
26+
static cl_program program;
27+
static cl_kernel kernel_q4_0, kernel_q4_1, kernel_q4_2, kernel_q4_3;
28+
static cl_mem cl_buffer_a, cl_buffer_qb, cl_buffer_b, cl_buffer_c;
29+
static size_t cl_size_a = 0, cl_size_qb = 0, cl_size_b = 0, cl_size_c = 0;
30+
31+
// Creates and builds an OpenCL program from a NUL-terminated source string.
//
//   ctx            - context the program is created in
//   dev            - device whose build log is reported when the build fails
//   program_buffer - OpenCL C source text
//
// Returns the built program. Any failure is fatal: a diagnostic (including the
// build log, when it can be retrieved) is written to stderr and the process
// exits with status 1.
static cl_program build_program_from_source(cl_context ctx, cl_device_id dev, const char* program_buffer) {
    cl_int err = CL_SUCCESS;
    const size_t program_size = strlen(program_buffer);

    cl_program p = clCreateProgramWithSource(ctx, 1, (const char**)&program_buffer, &program_size, &err);
    if (err != CL_SUCCESS) {
        fprintf(stderr, "OpenCL error creating program\n");
        exit(1);
    }

    err = clBuildProgram(p, 0, NULL, NULL, NULL, NULL);
    if (err != CL_SUCCESS) {
        fprintf(stderr, "OpenCL error %d building program\n", err);

        // Query the log length first; only trust it when the query succeeded.
        size_t log_size = 0;
        if (clGetProgramBuildInfo(p, dev, CL_PROGRAM_BUILD_LOG, 0, NULL, &log_size) == CL_SUCCESS) {
            char * program_log = (char*) malloc(log_size + 1);
            if (program_log != NULL) {
                // Keep the buffer a valid string even if the second query fails.
                program_log[0] = '\0';
                clGetProgramBuildInfo(p, dev, CL_PROGRAM_BUILD_LOG, log_size + 1, program_log, NULL);
                program_log[log_size] = '\0';
                fprintf(stderr, "%s\n", program_log);
                free(program_log);
            }
        }
        exit(1);
    }

    return p;
}
59+
60+
void ggml_cl_init(void) {
61+
cl_int err = 0;
62+
char * GGML_CLBLAST_PLATFORM = getenv("GGML_CLBLAST_PLATFORM");
63+
char * GGML_CLBLAST_DEVICE = getenv("GGML_CLBLAST_DEVICE");
64+
int plat_num = (GGML_CLBLAST_PLATFORM == NULL ? 0 : atoi(GGML_CLBLAST_PLATFORM));
65+
int dev_num = (GGML_CLBLAST_DEVICE == NULL ? 0 : atoi(GGML_CLBLAST_DEVICE));
66+
printf("\nInitializing CLBlast (First Run)...");
67+
printf("\nAttempting to use: Platform=%d, Device=%d (If invalid, program will crash)\n",plat_num,dev_num);
68+
cl_uint num_platforms;
69+
clGetPlatformIDs(0, NULL, &num_platforms);
70+
cl_platform_id* platforms = (cl_platform_id*)malloc(num_platforms*sizeof(cl_platform_id));
71+
clGetPlatformIDs(num_platforms, platforms, NULL);
72+
platform = platforms[plat_num];
73+
char platform_buffer[1024];
74+
clGetPlatformInfo(platform, CL_PLATFORM_NAME, sizeof(platform_buffer), &platform_buffer, NULL);
75+
cl_uint num_devices;
76+
clGetDeviceIDs(platform, CL_DEVICE_TYPE_ALL, 0, NULL, &num_devices);
77+
cl_device_id* devices = (cl_device_id*)malloc(num_devices*sizeof(cl_device_id));
78+
clGetDeviceIDs(platform, CL_DEVICE_TYPE_ALL, num_devices, devices, NULL);
79+
device = devices[dev_num];
80+
char device_buffer[1024];
81+
clGetDeviceInfo(device, CL_DEVICE_NAME, sizeof(device_buffer), &device_buffer, NULL);
82+
printf("Using Platform: %s Device: %s\n", platform_buffer, device_buffer);
83+
context = clCreateContext(NULL, 1, &device, NULL, NULL, &err);
84+
CL_CHECK(err, "clCreateContext");
85+
queue = clCreateCommandQueue(context, device, CL_QUEUE_OUT_OF_ORDER_EXEC_MODE_ENABLE, &err);
86+
CL_CHECK(err, "clCreateCommandQueue");
87+
88+
free(platforms);
89+
free(devices);
90+
91+
program = build_program_from_source(context, device, clblast_dequant);
92+
93+
// Prepare dequantize kernels
94+
kernel_q4_0 = clCreateKernel(program, "dequantize_row_q4_0", &err);
95+
CL_CHECK(err, "clCreateKernel");
96+
kernel_q4_1 = clCreateKernel(program, "dequantize_row_q4_1", &err);
97+
CL_CHECK(err, "clCreateKernel");
98+
kernel_q4_2 = clCreateKernel(program, "dequantize_row_q4_2", &err);
99+
CL_CHECK(err, "clCreateKernel");
100+
kernel_q4_3 = clCreateKernel(program, "dequantize_row_q4_3", &err);
101+
CL_CHECK(err, "clCreateKernel");
102+
}
103+
104+
// Ensures that *buf can hold at least req_size bytes.
//
//   req_size - number of bytes required
//   cur_size - in/out: current capacity of *buf in bytes (0 = nothing allocated)
//   flags    - OpenCL memory flags used if a new buffer has to be created
//   buf      - in/out: the buffer handle
//
// A buffer that is already large enough is left untouched. Otherwise the old
// buffer (if any) is released and a new one of exactly req_size bytes is
// created; a creation failure is fatal through CL_CHECK.
static void ggml_cl_malloc(size_t req_size, size_t* cur_size, cl_mem_flags flags, cl_mem* buf) {
    // Already big enough: nothing to do.
    if (*cur_size >= req_size) {
        return;
    }

    // Too small: drop the previous allocation before creating the larger one.
    if (*cur_size != 0) {
        clReleaseMemObject(*buf);
    }

    cl_int status;
    cl_mem grown = clCreateBuffer(context, flags, req_size, NULL, &status);
    *buf = grown;
    *cur_size = req_size;
    CL_CHECK(status, "clCreateBuffer");
}
118+
119+
void ggml_cl_sgemm_wrapper(
120+
const enum ggml_blas_order order, const enum ggml_blas_op trans_a, const enum ggml_blas_op trans_b,
121+
const int m, const int n, const int k,
122+
const float alpha, const void *host_a, const int lda,
123+
const float *host_b, const int ldb, const float beta,
124+
float *host_c, const int ldc, const int btype) {
125+
cl_int err = 0;
126+
127+
cl_kernel kernel;
128+
size_t global = n * k, local, size_qb;
129+
bool dequant;
130+
131+
switch (btype) {
132+
case GGML_TYPE_F32:
133+
dequant = false;
134+
break;
135+
case GGML_TYPE_Q4_0:
136+
dequant = true;
137+
kernel = kernel_q4_0;
138+
local = 16;
139+
size_qb = global * (sizeof(float) + local) / 32;
140+
break;
141+
case GGML_TYPE_Q4_1:
142+
dequant = true;
143+
kernel = kernel_q4_1;
144+
local = 16;
145+
size_qb = global * (sizeof(float) * 2 + local) / 32;
146+
break;
147+
case GGML_TYPE_Q4_2:
148+
dequant = true;
149+
kernel = kernel_q4_2;
150+
local = 8;
151+
size_qb = global * (sizeof(short) + local) / 16;
152+
break;
153+
case GGML_TYPE_Q4_3:
154+
dequant = true;
155+
kernel = kernel_q4_3;
156+
local = 8;
157+
size_qb = global * (sizeof(short) * 2 + local) / 16;
158+
break;
159+
default:
160+
fprintf(stderr, "Error: Unsupported OpenCL btype %d\n", btype);
161+
abort();
162+
}
163+
164+
const size_t size_a = m * k * sizeof(float);
165+
const size_t size_b = n * k * sizeof(float);
166+
const size_t size_c = m * n * sizeof(float);
167+
168+
// Prepare buffers
169+
ggml_cl_malloc(size_a, &cl_size_a, CL_MEM_READ_ONLY, &cl_buffer_a);
170+
if (dequant) {
171+
ggml_cl_malloc(size_qb, &cl_size_qb, CL_MEM_READ_ONLY, &cl_buffer_qb);
172+
}
173+
ggml_cl_malloc(size_b, &cl_size_b, CL_MEM_READ_WRITE, &cl_buffer_b);
174+
ggml_cl_malloc(size_c, &cl_size_c, CL_MEM_WRITE_ONLY, &cl_buffer_c);
175+
176+
cl_event ev_a, ev_qb, ev_b;
177+
178+
if (dequant) {
179+
err = clSetKernelArg(kernel, 0, sizeof(cl_mem), &cl_buffer_qb);
180+
err |= clSetKernelArg(kernel, 1, sizeof(cl_mem), &cl_buffer_b);
181+
CL_CHECK(err, "clSetKernelArg");
182+
clEnqueueWriteBuffer(queue, cl_buffer_qb, CL_FALSE, 0, size_qb, host_b, 0, NULL, &ev_qb);
183+
} else {
184+
clEnqueueWriteBuffer(queue, cl_buffer_b, CL_FALSE, 0, size_b, host_b, 0, NULL, &ev_b);
185+
}
186+
187+
clEnqueueWriteBuffer(queue, cl_buffer_a, CL_FALSE, 0, size_a, host_a, 0, NULL, &ev_a);
188+
if (dequant) {
189+
err = clEnqueueNDRangeKernel(queue, kernel, 1, NULL, &global, &local, 1, &ev_qb, &ev_b);
190+
CL_CHECK(err, "clEnqueueNDRangeKernel");
191+
clReleaseEvent(ev_qb);
192+
}
193+
clWaitForEvents(1, &ev_a);
194+
clWaitForEvents(1, &ev_b);
195+
clReleaseEvent(ev_a);
196+
clReleaseEvent(ev_b);
197+
198+
cl_event ev_sgemm;
199+
CLBlastSgemm((CLBlastLayout)order,
200+
(CLBlastTranspose)trans_a, (CLBlastTranspose)trans_b,
201+
m, n, k,
202+
alpha,
203+
cl_buffer_a, 0, lda,
204+
cl_buffer_b, 0, ldb,
205+
beta,
206+
cl_buffer_c, 0, ldc,
207+
&queue, &ev_sgemm);
208+
209+
cl_event ev_c;
210+
clEnqueueReadBuffer(queue, cl_buffer_c, CL_TRUE, 0, size_c, host_c, 1, &ev_sgemm, &ev_c);
211+
212+
// Wait for completion
213+
clWaitForEvents(1, &ev_c);
214+
clReleaseEvent(ev_sgemm);
215+
clReleaseEvent(ev_c);
216+
}

ggml-opencl.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
#pragma once

#ifdef __cplusplus
extern "C" {
#endif

// Sets up the OpenCL state used by ggml_cl_sgemm_wrapper().
void ggml_cl_init(void);

// Matrix storage order. ggml_cl_sgemm_wrapper() casts these values directly to
// the CLBlast layout enum, so the numeric values must not be changed.
enum ggml_blas_order {
    GGML_BLAS_ORDER_ROW_MAJOR    = 101,
    GGML_BLAS_ORDER_COLUMN_MAJOR = 102,
};

// Operation applied to an input matrix. ggml_cl_sgemm_wrapper() casts these
// values directly to the CLBlast transpose enum, so the numeric values must
// not be changed.
// NOTE(review): N/T/C presumably mean none / transpose / conjugate transpose -- confirm.
enum ggml_blas_op {
    GGML_BLAS_OP_N = 111,
    GGML_BLAS_OP_T = 112,
    GGML_BLAS_OP_C = 113,
};

// Single-precision matrix multiplication on the OpenCL device.
// host_a and host_b are the inputs, host_c receives the result; btype is the
// GGML type of the data behind host_b.
void ggml_cl_sgemm_wrapper(
    const enum ggml_blas_order order,
    const enum ggml_blas_op trans_a,
    const enum ggml_blas_op trans_b,
    const int m, const int n, const int k,
    const float alpha,
    const void *host_a, const int lda,
    const float *host_b, const int ldb,
    const float beta,
    float *host_c, const int ldc,
    const int btype);

#ifdef __cplusplus
}
#endif

0 commit comments

Comments
 (0)