Skip to content

Commit 3c4a83e

Browse files
small refactor
1 parent c0daa66 commit 3c4a83e

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

ggml-cuda.cu

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1415,8 +1415,8 @@ static __device__ void convert_f16(const void * vx, const int ib, const int iqs,
14151415
}
14161416

14171417
static __global__ void quantize_q8_1(
1418-
const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded, const int ky,
1419-
const int row_stride, const int channel_stride) {
1418+
const float * __restrict__ src, void * __restrict__ vdst, const int kx, const int kx_padded, const int ky,
1419+
const int ky_stride, const int channel_stride) {
14201420

14211421
const int ix = blockDim.x*blockIdx.x + threadIdx.x;
14221422

@@ -1427,14 +1427,17 @@ static __global__ void quantize_q8_1(
14271427
const int iy = blockDim.y*blockIdx.y + threadIdx.y;
14281428
const int channel = blockDim.z*blockIdx.z + threadIdx.z;
14291429

1430+
// padded and contiguous:
14301431
const int i_padded = channel*ky*kx_padded + iy*kx_padded + ix;
14311432

1432-
block_q8_1 * y = (block_q8_1 *) vy;
1433+
block_q8_1 * dst = (block_q8_1 *) vdst;
14331434

14341435
const int ib = i_padded / QK8_1; // block index
14351436
const int iqs = i_padded % QK8_1; // quant index
14361437

1437-
const float xi = ix < kx ? x[channel*channel_stride + iy*row_stride + ix] : 0.0f;
1438+
// not padded and not necessarily contiguous:
1439+
const float xi = ix < kx ? src[channel*channel_stride + iy*ky_stride + ix] : 0.0f;
1440+
14381441
float amax = fabsf(xi);
14391442
float sum = xi;
14401443

@@ -1447,14 +1450,14 @@ static __global__ void quantize_q8_1(
14471450
const float d = amax / 127;
14481451
const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);
14491452

1450-
y[ib].qs[iqs] = q;
1453+
dst[ib].qs[iqs] = q;
14511454

14521455
if (iqs > 0) {
14531456
return;
14541457
}
14551458

1456-
reinterpret_cast<half&>(y[ib].ds.x) = d;
1457-
reinterpret_cast<half&>(y[ib].ds.y) = sum;
1459+
reinterpret_cast<half&>(dst[ib].ds.x) = d;
1460+
reinterpret_cast<half&>(dst[ib].ds.y) = sum;
14581461
}
14591462

14601463
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>

0 commit comments

Comments
 (0)