Skip to content

Commit f600d0d

Browse files
CUDA: faster q8_0 -> f16 dequantization
1 parent 326b418 commit f600d0d

File tree

1 file changed

+57
-0
lines changed

1 file changed

+57
-0
lines changed

ggml-cuda.cu

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -519,6 +519,8 @@ static_assert(sizeof(block_iq2_xs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16
519519
#define CUDA_ACC_BLOCK_SIZE 256
520520
#define CUDA_IM2COL_BLOCK_SIZE 256
521521

522+
#define CUDA_Q8_0_NE_ALIGN 2048
523+
522524
// dmmv = dequantize_mul_mat_vec
523525
#ifndef GGML_CUDA_DMMV_X
524526
#define GGML_CUDA_DMMV_X 32
@@ -2327,6 +2329,45 @@ static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __res
23272329
y[i] = x[i];
23282330
}
23292331

2332+
template <bool need_check>
2333+
static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, half * __restrict__ y, const int k) {
2334+
#if __CUDA_ARCH__ >= CC_PASCAL
2335+
constexpr int nint = CUDA_Q8_0_NE_ALIGN/sizeof(int) + WARP_SIZE;
2336+
2337+
const int i0 = CUDA_Q8_0_NE_ALIGN*blockIdx.x;
2338+
const int * x0 = ((int *) vx) + blockIdx.x * nint;
2339+
half2 * y2 = (half2 *) (y + i0);
2340+
2341+
__shared__ int vals[nint];
2342+
2343+
#pragma unroll
2344+
for (int ix0 = 0; ix0 < nint; ix0 += WARP_SIZE) {
2345+
if (need_check && i0*sizeof(block_q8_0)/QK8_0 + sizeof(int)*(ix0 + threadIdx.x) >= k*sizeof(block_q8_0)/QK8_0) {
2346+
break;
2347+
}
2348+
2349+
const int ix = ix0 + threadIdx.x;
2350+
vals[ix] = x0[ix];
2351+
}
2352+
2353+
#pragma unroll
2354+
for (int iy = 0; iy < CUDA_Q8_0_NE_ALIGN; iy += 2*WARP_SIZE) {
2355+
if (need_check && i0 + iy + 2*threadIdx.x >= k) {
2356+
return;
2357+
}
2358+
2359+
const half * b0 = ((const half *) vals) + (sizeof(block_q8_0)/sizeof(half)) * ((iy + 2*threadIdx.x)/QK8_0);
2360+
const half d = *b0;
2361+
const char2 qs = ((const char2 *) (b0 + 1))[threadIdx.x % (QK8_0/2)];
2362+
2363+
y2[iy/2 + threadIdx.x] = __hmul2(make_half2(qs.x, qs.y), __half2half2(d));
2364+
}
2365+
#else
2366+
(void) vx; (void) y; (void) k;
2367+
bad_arch();
2368+
#endif // __CUDA_ARCH__ >= CC_PASCAL
2369+
}
2370+
23302371
// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
23312372
// MMVQ = mul_mat_vec_q, MMQ = mul_mat_q
23322373

@@ -6181,6 +6222,17 @@ static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restri
61816222
dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
61826223
}
61836224

6225+
static void dequantize_block_q8_0_f16_cuda(const void * __restrict__ vx, half * __restrict__ y, const int k, cudaStream_t stream) {
6226+
const int num_blocks = (k + CUDA_Q8_0_NE_ALIGN - 1) / CUDA_Q8_0_NE_ALIGN;
6227+
if (k % CUDA_Q8_0_NE_ALIGN == 0) {
6228+
const bool need_check = false;
6229+
dequantize_block_q8_0_f16<need_check><<<num_blocks, WARP_SIZE, 0, stream>>>(vx, y, k);
6230+
} else {
6231+
const bool need_check = true;
6232+
dequantize_block_q8_0_f16<need_check><<<num_blocks, WARP_SIZE, 0, stream>>>(vx, y, k);
6233+
}
6234+
}
6235+
61846236
template<typename dst_t>
61856237
static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
61866238
const int nb = k / QK_K;
@@ -6246,6 +6298,7 @@ static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict_
62466298
}
62476299

62486300
static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
6301+
int id;
62496302
switch (type) {
62506303
case GGML_TYPE_Q4_0:
62516304
return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
@@ -6256,6 +6309,10 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
62566309
case GGML_TYPE_Q5_1:
62576310
return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
62586311
case GGML_TYPE_Q8_0:
6312+
CUDA_CHECK(cudaGetDevice(&id));
6313+
if (g_device_caps[id].cc >= CC_PASCAL) {
6314+
return dequantize_block_q8_0_f16_cuda;
6315+
}
62596316
return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
62606317
case GGML_TYPE_Q2_K:
62616318
return dequantize_row_q2_K_cuda;

0 commit comments

Comments
 (0)