Skip to content

Commit ed267f2

Browse files
committed
fix quantize bug
1 parent 7d8c960 commit ed267f2

File tree

1 file changed

+35
-17
lines changed

1 file changed

+35
-17
lines changed

ggml/src/ggml-sycl.cpp

Lines changed: 35 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -358,10 +358,11 @@ static void pad_f32(const float *x, float *dst, const int ne0, const int ne00,
358358
}
359359
}
360360

361+
template<int QUANT_BLOCK_TILE>
361362
static void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded,
362363
const sycl::nd_item<3> &item_ct1) {
363-
const int ix = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
364-
item_ct1.get_local_id(2);
364+
const int ix = (item_ct1.get_local_range(2) * item_ct1.get_group(2) +
365+
item_ct1.get_local_id(2)) * QUANT_BLOCK_TILE;
365366

366367
if (ix >= kx_padded) {
367368
return;
@@ -376,23 +377,39 @@ static void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy,
376377

377378
const int ib = i_padded / QK8_1; // block index
378379
const int iqs = i_padded % QK8_1; // quant index
379-
380-
const float xi = ix < kx ? x[iy*kx + ix] : 0.0f;
381-
float amax = sycl::fabs((float)xi);
382-
float sum = xi;
383-
380+
typedef sycl::vec<float, QUANT_BLOCK_TILE> TC;
381+
typedef sycl::vec<int8_t, QUANT_BLOCK_TILE> TQ;
382+
TC zeros;
383+
TQ qzeros;
384384
#pragma unroll
385-
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
386-
amax = sycl::fmax(amax, dpct::permute_sub_group_by_xor(
387-
item_ct1.get_sub_group(), amax, mask));
388-
sum +=
389-
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), sum, mask);
385+
for (int i = 0; i < QUANT_BLOCK_TILE; i++)
386+
{
387+
zeros[i] = 0.f;
388+
qzeros[i] = 0;
389+
}
390+
const TC xi = ix < kx ? *(TC *)&x[iy * kx + ix] : zeros;
391+
float sum = xi[0];
392+
float amax = sycl::fabs(xi[0]);
393+
#pragma unroll
394+
for (int i = 1; i < QUANT_BLOCK_TILE; i++)
395+
{
396+
sum += xi[i];
397+
amax = sycl::fmax(sycl::fabs(xi[i]), amax);
390398
}
399+
sum = warp_reduce_sum(sum, item_ct1);
400+
amax = warp_reduce_max(amax, item_ct1);
391401

392402
const float d = amax / 127;
393-
const int8_t q = amax == 0.0f ? 0 : sycl::round(xi / d);
403+
TQ q = qzeros;
404+
if (amax != 0.0f)
405+
{
406+
#pragma unroll
407+
for (int i = 0; i < QUANT_BLOCK_TILE; i++) {
408+
q[i] = sycl::round(xi[i] / d);
409+
}
410+
}
394411

395-
y[ib].qs[iqs] = q;
412+
*(TQ *)&y[ib].qs[iqs] = q;
396413

397414
if (iqs > 0) {
398415
return;
@@ -1595,15 +1612,17 @@ static void quantize_row_q8_1_sycl(const float *x, void *vy, const int kx,
15951612
queue_ptr stream) {
15961613
const int block_num_x = (kx_padded + SYCL_QUANTIZE_BLOCK_SIZE - 1) / SYCL_QUANTIZE_BLOCK_SIZE;
15971614
const sycl::range<3> num_blocks(1, ky, block_num_x);
1598-
const sycl::range<3> block_size(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE);
1615+
int constexpr QUANT_BLOCK_TILE = QK8_1 / WARP_SIZE;
1616+
static_assert(QK8_1 % WARP_SIZE == 0);
1617+
const sycl::range<3> block_size(1, 1, SYCL_QUANTIZE_BLOCK_SIZE / QUANT_BLOCK_TILE);
15991618
{
16001619
dpct::has_capability_or_fail(stream->get_device(),
16011620
{sycl::aspect::fp16});
16021621

16031622
stream->parallel_for(
16041623
sycl::nd_range<3>(num_blocks * block_size, block_size),
16051624
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
1606-
quantize_q8_1(x, vy, kx, kx_padded, item_ct1);
1625+
quantize_q8_1<QUANT_BLOCK_TILE>(x, vy, kx, kx_padded, item_ct1);
16071626
});
16081627
}
16091628
}
@@ -4170,7 +4189,6 @@ bool ggml_sycl_supports_dmmv(enum ggml_type type) {
41704189

41714190
static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
41724191
const bool split = ggml_backend_buffer_is_sycl_split(src0->buffer);
4173-
41744192
int64_t min_compute_capability = INT_MAX;
41754193

41764194
if (split) {

0 commit comments

Comments
 (0)