Skip to content

Commit 02d6988

Browse files
authored
Improve cuBLAS performance by dequantizing on the GPU (#1065)
1 parent 834695f commit 02d6988

File tree

5 files changed

+221
-41
lines changed

5 files changed

+221
-41
lines changed

CMakeLists.txt

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ if (APPLE AND LLAMA_ACCELERATE)
110110
message(WARNING "Accelerate framework not found")
111111
endif()
112112
endif()
113+
113114
if (LLAMA_OPENBLAS)
114115
if (LLAMA_STATIC)
115116
set(BLA_STATIC ON)
@@ -150,6 +151,10 @@ if (LLAMA_CUBLAS)
150151
if (CUDAToolkit_FOUND)
151152
message(STATUS "cuBLAS found")
152153

154+
enable_language(CUDA)
155+
156+
set(GGML_CUDA_SOURCES ggml-cuda.cu ggml-cuda.h)
157+
153158
add_compile_definitions(GGML_USE_CUBLAS)
154159

155160
if (LLAMA_STATIC)
@@ -241,21 +246,26 @@ elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "^(x86_64|i686|AMD64)$")
241246
message(STATUS "x86 detected")
242247
if (MSVC)
243248
if (LLAMA_AVX512)
244-
add_compile_options(/arch:AVX512)
249+
add_compile_options($<$<COMPILE_LANGUAGE:C>:/arch:AVX512>)
250+
add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/arch:AVX512>)
245251
# MSVC has no compile-time flags enabling specific
246252
# AVX512 extensions, neither it defines the
247253
# macros corresponding to the extensions.
248254
# Do it manually.
249255
if (LLAMA_AVX512_VBMI)
250-
add_compile_definitions(__AVX512VBMI__)
256+
add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VBMI__>)
257+
add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VBMI__>)
251258
endif()
252259
if (LLAMA_AVX512_VNNI)
253-
add_compile_definitions(__AVX512VNNI__)
260+
add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VNNI__>)
261+
add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VNNI__>)
254262
endif()
255263
elseif (LLAMA_AVX2)
256-
add_compile_options(/arch:AVX2)
264+
add_compile_options($<$<COMPILE_LANGUAGE:C>:/arch:AVX2>)
265+
add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/arch:AVX2>)
257266
elseif (LLAMA_AVX)
258-
add_compile_options(/arch:AVX)
267+
add_compile_options($<$<COMPILE_LANGUAGE:C>:/arch:AVX>)
268+
add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/arch:AVX>)
259269
endif()
260270
else()
261271
if (LLAMA_F16C)
@@ -292,7 +302,8 @@ endif()
292302

293303
add_library(ggml OBJECT
294304
ggml.c
295-
ggml.h)
305+
ggml.h
306+
${GGML_CUDA_SOURCES})
296307

297308
target_include_directories(ggml PUBLIC .)
298309
target_compile_features(ggml PUBLIC c_std_11) # don't bump
@@ -314,6 +325,14 @@ if (BUILD_SHARED_LIBS)
314325
target_compile_definitions(llama PRIVATE LLAMA_SHARED LLAMA_BUILD)
315326
endif()
316327

328+
if (GGML_CUDA_SOURCES)
329+
message(STATUS "GGML CUDA sources found, configuring CUDA architecture")
330+
set_property(TARGET ggml PROPERTY CUDA_ARCHITECTURES OFF)
331+
set_property(TARGET ggml PROPERTY CUDA_SELECT_NVCC_ARCH_FLAGS "Auto")
332+
set_property(TARGET llama PROPERTY CUDA_ARCHITECTURES OFF)
333+
endif()
334+
335+
317336
#
318337
# programs, examples and tests
319338
#

Makefile

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
# Define the default target now so that it is always the first target
2+
default: main quantize quantize-stats perplexity embedding vdot
3+
14
ifndef UNAME_S
25
UNAME_S := $(shell uname -s)
36
endif
@@ -100,6 +103,9 @@ endif
100103
ifdef LLAMA_CUBLAS
101104
CFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include
102105
LDFLAGS += -lcublas_static -lculibos -lcudart_static -lcublasLt_static -lpthread -ldl -L/usr/local/cuda/lib64
106+
OBJS += ggml-cuda.o
107+
ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
108+
nvcc -arch=native -c -o $@ $<
103109
endif
104110
ifdef LLAMA_GPROF
105111
CFLAGS += -pg
@@ -137,8 +143,6 @@ $(info I CC: $(CCV))
137143
$(info I CXX: $(CXXV))
138144
$(info )
139145

140-
default: main quantize quantize-stats perplexity embedding vdot
141-
142146
#
143147
# Build library
144148
#
@@ -155,35 +159,35 @@ common.o: examples/common.cpp examples/common.h
155159
clean:
156160
rm -vf *.o main quantize quantize-stats perplexity embedding benchmark-q4_0-matmult
157161

158-
main: examples/main/main.cpp ggml.o llama.o common.o
162+
main: examples/main/main.cpp ggml.o llama.o common.o $(OBJS)
159163
$(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)
160164
@echo
161165
@echo '==== Run ./main -h for help. ===='
162166
@echo
163167

164-
quantize: examples/quantize/quantize.cpp ggml.o llama.o
168+
quantize: examples/quantize/quantize.cpp ggml.o llama.o $(OBJS)
165169
$(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)
166170

167-
quantize-stats: examples/quantize-stats/quantize-stats.cpp ggml.o llama.o
171+
quantize-stats: examples/quantize-stats/quantize-stats.cpp ggml.o llama.o $(OBJS)
168172
$(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)
169173

170-
perplexity: examples/perplexity/perplexity.cpp ggml.o llama.o common.o
174+
perplexity: examples/perplexity/perplexity.cpp ggml.o llama.o common.o $(OBJS)
171175
$(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)
172176

173-
embedding: examples/embedding/embedding.cpp ggml.o llama.o common.o
177+
embedding: examples/embedding/embedding.cpp ggml.o llama.o common.o $(OBJS)
174178
$(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)
175179

176-
vdot: pocs/vdot/vdot.cpp ggml.o
180+
vdot: pocs/vdot/vdot.cpp ggml.o $(OBJS)
177181
$(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)
178182

179-
libllama.so: llama.o ggml.o
183+
libllama.so: llama.o ggml.o $(OBJS)
180184
$(CXX) $(CXXFLAGS) -shared -fPIC -o $@ $^ $(LDFLAGS)
181185

182186
#
183187
# Tests
184188
#
185189

186-
benchmark: examples/benchmark/benchmark-q4_0-matmult.c ggml.o
190+
benchmark: examples/benchmark/benchmark-q4_0-matmult.c ggml.o $(OBJS)
187191
$(CXX) $(CXXFLAGS) $^ -o benchmark-q4_0-matmult $(LDFLAGS)
188192
./benchmark-q4_0-matmult
189193

ggml-cuda.cu

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
#include <stdint.h>
#include <cuda_fp16.h>
#include "ggml-cuda.h"

// Host-side ggml stores fp16 as a raw uint16_t; the device uses __half.
// The two must be layout-compatible so quantized blocks can be copied verbatim.
typedef uint16_t ggml_fp16_t;
static_assert(sizeof(__half) == sizeof(ggml_fp16_t), "wrong fp16 size");

// Quantization block layouts. These mirror the CPU-side definitions in ggml.c;
// the static_asserts pin the exact byte size so device and host never disagree.

#define QK4_0 32
typedef struct {
    float   d;              // delta
    uint8_t qs[QK4_0 / 2];  // nibbles / quants
} block_q4_0;
static_assert(sizeof(block_q4_0) == sizeof(float) + QK4_0 / 2, "wrong q4_0 block size/padding");

#define QK4_1 32
typedef struct {
    float   d;              // delta
    float   m;              // min
    uint8_t qs[QK4_1 / 2];  // nibbles / quants
} block_q4_1;
static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding");

#define QK4_2 16
typedef struct {
    __half  d;              // delta
    uint8_t qs[QK4_2 / 2];  // nibbles / quants
} block_q4_2;
static_assert(sizeof(block_q4_2) == sizeof(ggml_fp16_t) + QK4_2 / 2, "wrong q4_2 block size/padding");

// Dequantize one q4_0 block per CUDA thread block: writes the QK4_0 floats
// y[blockIdx.x*QK4_0 .. +QK4_0-1]. Quants are 4-bit values stored with a -8 bias.
static __global__ void dequantize_block_q4_0(const void * vx, float * y) {
    const block_q4_0 * blocks = (const block_q4_0 *) vx;

    const int ib = blockIdx.x;          // which quantized block this thread block handles

    const float     scale = blocks[ib].d;
    const uint8_t * quants = blocks[ib].qs;

    float * dst = y + ib*QK4_0;

    // each packed byte holds two quants: low nibble first, then high nibble
    for (int j = 0; j < QK4_0/2; ++j) {
        const uint8_t packed = quants[j];

        dst[2*j + 0] = ((int8_t)(packed & 0xf) - 8)*scale;
        dst[2*j + 1] = ((int8_t)(packed >>  4) - 8)*scale;
    }
}

// Dequantize one q4_1 block per CUDA thread block: writes the QK4_1 floats
// y[blockIdx.x*QK4_1 .. +QK4_1-1]. q4_1 stores an unsigned 4-bit quant plus
// a per-block minimum m, so the reconstruction is q*d + m.
static __global__ void dequantize_block_q4_1(const void * vx, float * y) {
    const block_q4_1 * blocks = (const block_q4_1 *) vx;

    const int ib = blockIdx.x;          // which quantized block this thread block handles

    const float     scale  = blocks[ib].d;
    const float     minval = blocks[ib].m;
    const uint8_t * quants = blocks[ib].qs;

    float * dst = y + ib*QK4_1;

    // each packed byte holds two quants: low nibble first, then high nibble
    for (int j = 0; j < QK4_1/2; ++j) {
        const uint8_t packed = quants[j];

        dst[2*j + 0] = (int8_t)(packed & 0xf)*scale + minval;
        dst[2*j + 1] = (int8_t)(packed >>  4)*scale + minval;
    }
}

// Dequantize one q4_2 block per CUDA thread block: writes the QK4_2 floats
// y[blockIdx.x*QK4_2 .. +QK4_2-1]. Same biased-nibble scheme as q4_0, but the
// delta is stored as fp16 (__half) and widened to float here.
static __global__ void dequantize_block_q4_2(const void * vx, float * y) {
    const block_q4_2 * blocks = (const block_q4_2 *) vx;

    const int ib = blockIdx.x;          // which quantized block this thread block handles

    const float     scale  = blocks[ib].d;   // __half -> float
    const uint8_t * quants = blocks[ib].qs;

    float * dst = y + ib*QK4_2;

    // each packed byte holds two quants: low nibble first, then high nibble
    for (int j = 0; j < QK4_2/2; ++j) {
        const uint8_t packed = quants[j];

        dst[2*j + 0] = ((int8_t)(packed & 0xf) - 8)*scale;
        dst[2*j + 1] = ((int8_t)(packed >>  4) - 8)*scale;
    }
}

extern "C" {

// C-callable launchers. Each dequantizes k quantized values from vx into the
// float row y, asynchronously on `stream`. One single-thread CUDA block is
// launched per quantized block; any remainder of k modulo the block size is
// dropped by the integer division.

__host__ void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
    dequantize_block_q4_0<<<k / QK4_0, 1, 0, stream>>>(vx, y);
}

__host__ void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
    dequantize_block_q4_1<<<k / QK4_1, 1, 0, stream>>>(vx, y);
}

__host__ void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
    dequantize_block_q4_2<<<k / QK4_2, 1, 0, stream>>>(vx, y);
}

}

ggml-cuda.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
#ifdef __cplusplus
2+
extern "C" {
3+
#endif
4+
5+
void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream);
6+
void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream);
7+
void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream);
8+
9+
#ifdef __cplusplus
10+
}
11+
#endif

0 commit comments

Comments
 (0)