Skip to content

Commit 6745ea7

Browse files
author
S
committed
dranger003: Fix block index overflow in CUDA dequantizing.
1 parent c2658c3 commit 6745ea7

File tree

3 files changed, with 8 additions and 8 deletions.

ggml-cuda/common.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -394,7 +394,7 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
394 394
// TODO: move to ggml-common.h
395 395
static const __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
396 396

397 -
typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v);
397 +
typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);
398 398

399 399

400 400
//////////////////////

ggml-cuda/dequantize.cuh

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1 1
#include "common.cuh"
2 2

3 -
static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
3 +
static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
4 4
const block_q4_0 * x = (const block_q4_0 *) vx;
5 5

6 6
const dfloat d = x[ib].d;
@@ -19,7 +19,7 @@ static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const in
19 19
#endif // GGML_CUDA_F16
20 20
}
21 21

22 -
static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
22 +
static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
23 23
const block_q4_1 * x = (const block_q4_1 *) vx;
24 24

25 25
const dfloat d = __low2half(x[ib].dm);
@@ -39,7 +39,7 @@ static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const in
39 39
#endif // GGML_CUDA_F16
40 40
}
41 41

42 -
static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
42 +
static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
43 43
const block_q5_0 * x = (const block_q5_0 *) vx;
44 44

45 45
const dfloat d = x[ib].d;
@@ -62,7 +62,7 @@ static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const in
62 62
#endif // GGML_CUDA_F16
63 63
}
64 64

65 -
static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
65 +
static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
66 66
const block_q5_1 * x = (const block_q5_1 *) vx;
67 67

68 68
const dfloat d = __low2half(x[ib].dm);
@@ -86,7 +86,7 @@ static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const in
86 86
#endif // GGML_CUDA_F16
87 87
}
88 88

89 -
static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
89 +
static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
90 90
const block_q8_0 * x = (const block_q8_0 *) vx;
91 91

92 92
const dfloat d = x[ib].d;

ggml-cuda/dmmv.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -565,7 +565,7 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx,
565 565
}
566 566
}
567 567

568 -
static __device__ void convert_f16(const void * vx, const int ib, const int iqs, dfloat2 & v){
568 +
static __device__ void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
569 569
const half * x = (const half *) vx;
570 570

571 571
// automatic half -> float type cast if dfloat == float
@@ -598,7 +598,7 @@ static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, cons
598 598

599 599
for (int i = 0; i < ncols; i += iter_stride) {
600 600
const int col = i + vals_per_iter*tid;
601 -
const int ib = (row*ncols + col)/qk; // x block index
601 +
const int64_t ib = ((int64_t)row*ncols + col)/qk; // x block index
602 602
const int iqs = (col%qk)/qr; // x quant index
603 603
const int iybs = col - col%qk; // y block start index
604 604

0 commit comments

Comments
 (0)