@@ -1415,8 +1415,8 @@ static __device__ void convert_f16(const void * vx, const int ib, const int iqs,
1415
1415
}
1416
1416
1417
1417
static __global__ void quantize_q8_1 (
1418
- const float * __restrict__ x , void * __restrict__ vy , const int kx, const int kx_padded, const int ky,
1419
- const int row_stride , const int channel_stride) {
1418
+ const float * __restrict__ src , void * __restrict__ vdst , const int kx, const int kx_padded, const int ky,
1419
+ const int ky_stride , const int channel_stride) {
1420
1420
1421
1421
const int ix = blockDim .x *blockIdx .x + threadIdx .x ;
1422
1422
@@ -1427,14 +1427,17 @@ static __global__ void quantize_q8_1(
1427
1427
const int iy = blockDim .y *blockIdx .y + threadIdx .y ;
1428
1428
const int channel = blockDim .z *blockIdx .z + threadIdx .z ;
1429
1429
1430
+ // padded and contiguous:
1430
1431
const int i_padded = channel*ky*kx_padded + iy*kx_padded + ix;
1431
1432
1432
- block_q8_1 * y = (block_q8_1 *) vy ;
1433
+ block_q8_1 * dst = (block_q8_1 *) vdst ;
1433
1434
1434
1435
const int ib = i_padded / QK8_1; // block index
1435
1436
const int iqs = i_padded % QK8_1; // quant index
1436
1437
1437
- const float xi = ix < kx ? x[channel*channel_stride + iy*row_stride + ix] : 0 .0f ;
1438
+ // not padded and not necessarily contiguous:
1439
+ const float xi = ix < kx ? src[channel*channel_stride + iy*ky_stride + ix] : 0 .0f ;
1440
+
1438
1441
float amax = fabsf (xi);
1439
1442
float sum = xi;
1440
1443
@@ -1447,14 +1450,14 @@ static __global__ void quantize_q8_1(
1447
1450
const float d = amax / 127 ;
1448
1451
const int8_t q = amax == 0 .0f ? 0 : roundf (xi / d);
1449
1452
1450
- y [ib].qs [iqs] = q;
1453
+ dst [ib].qs [iqs] = q;
1451
1454
1452
1455
if (iqs > 0 ) {
1453
1456
return ;
1454
1457
}
1455
1458
1456
- reinterpret_cast <half&>(y [ib].ds .x ) = d;
1457
- reinterpret_cast <half&>(y [ib].ds .y ) = sum;
1459
+ reinterpret_cast <half&>(dst [ib].ds .x ) = d;
1460
+ reinterpret_cast <half&>(dst [ib].ds .y ) = sum;
1458
1461
}
1459
1462
1460
1463
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
0 commit comments