Commit f547c58

Make ggml-cuda.cu build with QK_K = 64
Using LLAMA_CUDA_FORCE_DMMV = ON and -nommq, it runs and produces a meaningful result.
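For context: QK_K is the k-quants super-block size and is fixed at compile time; defining GGML_QKK_64 selects the 64-element variant that this commit makes buildable. A minimal sketch of that switch, mirroring the definition in k_quants.h (shown for orientation only, not part of this diff):

    // Compile-time selection of the k-quants super-block size.
    // GGML_QKK_64 is typically enabled via the LLAMA_QKK_64 build option.
    #ifdef GGML_QKK_64
    #define QK_K 64     // small 64-element super-blocks
    #else
    #define QK_K 256    // default super-block size
    #endif

LLAMA_CUDA_FORCE_DMMV = ON and -nommq both steer quantized matrix multiplication away from the mul-mat-q (MMQ) kernels, which this commit only stubs out for QK_K = 64 (see the last hunk below), and onto the dequantize-mul-mat-vec (DMMV) path instead.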
1 parent 771551a · commit f547c58

1 file changed: ggml-cuda.cu (17 additions, 8 deletions)
@@ -306,11 +306,11 @@ typedef struct {
 #define QI4_K (QK_K / (4*QR4_K))
 #ifdef GGML_QKK_64
 typedef struct {
-    half    d[2];          // super-block scales/mins
+    half    dm[2];         // super-block scales/mins
     uint8_t scales[2];     // 4-bit block scales/mins
     uint8_t qs[QK_K/2];    // 4-bit quants
 } block_q4_K;
-static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + QK_K/2 + 2, "wrong q4_K block size/padding");
+static_assert(sizeof(block_q4_K) == sizeof(half2) + QK_K/2 + 2, "wrong q4_K block size/padding");
 #else
 typedef struct {
     half2 dm;              // super-block scale for quantized scales/mins
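The renamed assert is equivalent to the old one because sizeof(half2) == 2*sizeof(ggml_fp16_t) == 4 bytes; using half2 here simply matches the dm naming and typing of the QK_K == 256 variant, which the mul-mat-q tile loaders below rely on. A self-contained check under those assumptions (block_q4_K_64 is an illustrative stand-in name, not from the source):

    #include <cuda_fp16.h>
    #include <cstdint>

    // Stand-in for the QK_K == 64 layout above (QK_K/2 == 32).
    typedef struct {
        half    dm[2];     // super-block scale (d) and min (m)
        uint8_t scales[2]; // 4-bit block scales/mins
        uint8_t qs[32];    // 32 bytes holding 64 4-bit quants
    } block_q4_K_64;

    // 4 (two halves) + 2 (scales) + 32 (quants) = 38 bytes, unchanged layout.
    static_assert(sizeof(block_q4_K_64) == sizeof(half2) + 32 + 2,
                  "wrong q4_K block size/padding");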
@@ -737,8 +737,8 @@ static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, float
     const int tid = threadIdx.x;
     const uint8_t * q = x[i].qs;
     float * y = yy + i*QK_K;
-    const float d = (float)x[i].d[0];
-    const float m = (float)x[i].d[1];
+    const float d = (float)x[i].dm[0];
+    const float m = (float)x[i].dm[1];
     y[tid+ 0] = d * (x[i].scales[0] & 0xF) * (q[tid] & 0xF) - m * (x[i].scales[0] >> 4);
     y[tid+32] = d * (x[i].scales[1] & 0xF) * (q[tid] >> 4) - m * (x[i].scales[1] >> 4);
 #endif
@@ -1155,8 +1155,8 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx,
     const uint16_t * a = (const uint16_t *)x[i].scales;
     aux16[0] = a[0] & 0x0f0f;
     aux16[1] = (a[0] >> 4) & 0x0f0f;
-    const float d = (float)x[i].d[0];
-    const float m = (float)x[i].d[1];
+    const float d = (float)x[i].dm[0];
+    const float m = (float)x[i].dm[1];
     float sum = 0.f;
     for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
         sum += y[j+ 0] * (d * s[0] * (q[j+ 0] & 0xF) - m * s[2])
@@ -2845,8 +2845,8 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
     aux16[0] = a[0] & 0x0f0f;
     aux16[1] = (a[0] >> 4) & 0x0f0f;

-    const float dall = bq4_K->d[0];
-    const float dmin = bq4_K->d[1];
+    const float dall = bq4_K->dm[0];
+    const float dmin = bq4_K->dm[1];

     const float d8_1 = __low2float(bq8_1[0].ds);
     const float d8_2 = __low2float(bq8_1[1].ds);
@@ -2929,7 +2929,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__

         const block_q4_K * bxi = bx0 + i*blocks_per_row + kbxd;

+#if QK_K == 256
         x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = bxi->dm;
+#else
+        x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = {bxi->dm[0], bxi->dm[1]};
+#endif
     }

 #pragma unroll
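Under QK_K == 64, dm is an array of two scalar halves rather than a half2, so the #else branch packs the pair into the half2 tile slot with a brace-init. A minimal device-side sketch of the same packing via the CUDA intrinsic (pack_dm is an illustrative helper, not from the source):

    #include <cuda_fp16.h>

    // Combine the separately stored scale and min into one half2, as the
    // brace-init {bxi->dm[0], bxi->dm[1]} above does.
    __device__ half2 pack_dm(const half d, const half m) {
        return __halves2half2(d, m);
    }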
@@ -3119,7 +3123,9 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__

         const block_q5_K * bxi = bx0 + i*blocks_per_row + kbxd;

+#if QK_K == 256
         x_dm[i * (WARP_SIZE/QI5_K) + i / QI5_K + kbxd] = bxi->dm;
+#endif
     }

 #pragma unroll
@@ -4709,6 +4715,8 @@ static void ggml_mul_mat_q3_K_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {

+#if QK_K == 256
+
     int id;
     CUDA_CHECK(cudaGetDevice(&id));
     const int compute_capability = g_compute_capabilities[id];
@@ -4740,6 +4748,7 @@ static void ggml_mul_mat_q3_K_q8_1_cuda(
         mul_mat_q3_K<need_check><<<block_nums, block_dims, 0, stream>>>
             (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     }
+#endif
 }

 static void ggml_mul_mat_q4_K_q8_1_cuda(
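With QK_K != 256 the whole body of ggml_mul_mat_q3_K_q8_1_cuda is compiled out, leaving an empty stub, which is why the commit message relies on LLAMA_CUDA_FORCE_DMMV = ON or -nommq: the MMQ path must not be selected at runtime. A hypothetical dispatch sketch of that constraint, assuming the QK_K definition from the first sketch (this helper is illustrative, not code from this commit):

    // Returns whether the mul-mat-q kernels may be used for q3_K.
    static bool use_mul_mat_q(bool force_dmmv, bool nommq) {
    #if QK_K == 256
        return !force_dmmv && !nommq; // MMQ kernels are compiled in
    #else
        (void)force_dmmv; (void)nommq;
        return false;                 // MMQ launcher above is an empty stub
    #endif
    }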
