Skip to content

Commit 663445b

Browse files
AD2605Alcpz
andauthored
sycl: quantize and reorder the input to q8_1 when reorder is enabled (#13826)
* [WIP]: fuse q8 quantization and reorder * wip2: fuse q8 quantization and reorder * working q8 reorder commit * restored common.hpp * remove debug prints * remove unnecessary headers and remove trailing whitespace * Update ggml/src/ggml-sycl/ggml-sycl.cpp Co-authored-by: Alberto Cabrera Pérez <[email protected]> --------- Co-authored-by: Alberto Cabrera Pérez <[email protected]>
1 parent 7675c55 commit 663445b

File tree

3 files changed

+120
-28
lines changed

3 files changed

+120
-28
lines changed

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 79 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1434,6 +1434,59 @@ static void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy,
14341434
reinterpret_cast<sycl::half &>(y[ib].ds.y()) = sum;
14351435
}
14361436

1437+
template <int ElementsPerWI>
1438+
static __dpct_inline__ void quantize_and_reorder_q8_1(const float * __restrict__ x, void * reordered_q8_tensor,
1439+
const int kx, const int kx_padded, const sycl::nd_item<1> & it) {
1440+
/*
1441+
Quantizes and reorders the resultant q8 tensor in a per row fashion
1442+
Each sub-group calculates one quant block. i.e. QK8_1 quant values and the d and sum values
1443+
*/
1444+
1445+
auto subgroup_id = it.get_group(0);
1446+
auto wi_id = it.get_local_id(0);
1447+
1448+
const int num_blocks_per_row = kx / QK8_1;
1449+
auto row = subgroup_id / num_blocks_per_row;
1450+
auto col = subgroup_id % num_blocks_per_row;
1451+
1452+
auto row_offset = row * (kx_padded / QK8_1) * sizeof(block_q8_1);
1453+
auto col_offset = QK8_1 * col + wi_id * ElementsPerWI;
1454+
1455+
auto quant_ptr = (int8_t *) ((char *) reordered_q8_tensor + row_offset + col_offset);
1456+
auto ds_ptr = (sycl::half2 *) ((char *) reordered_q8_tensor + row_offset + kx + col * sizeof(sycl::half2));
1457+
1458+
sycl::vec<float, ElementsPerWI> wi_f32_vals;
1459+
sycl::vec<int8_t, ElementsPerWI> quantized_values;
1460+
1461+
auto float_ptr_offset = subgroup_id * QK8_1 + ElementsPerWI * wi_id;
1462+
wi_f32_vals = *reinterpret_cast<const sycl::vec<float, ElementsPerWI> *>(x + float_ptr_offset);
1463+
1464+
float sum = 0.0f;
1465+
float amax = 0.0f;
1466+
1467+
#pragma unroll(ElementsPerWI)
1468+
for (int i = 0; i < ElementsPerWI; i++) {
1469+
sum += wi_f32_vals[i];
1470+
amax = sycl::fmax(amax, sycl::fabs(wi_f32_vals[i]));
1471+
quantized_values[i] = 0;
1472+
}
1473+
sum = sycl::reduce_over_group(it.get_group(), sum, sycl::plus<float>());
1474+
amax = sycl::reduce_over_group(it.get_group(), amax, sycl::maximum<float>());
1475+
float d = amax == 0 ? 1 : amax / 127;
1476+
1477+
#pragma unroll(ElementsPerWI)
1478+
for (int i = 0; i < ElementsPerWI; i++) {
1479+
quantized_values[i] = sycl::round(wi_f32_vals[i] / d);
1480+
}
1481+
1482+
d = amax == 0 ? 0 : d;
1483+
1484+
*reinterpret_cast<sycl::vec<int8_t, ElementsPerWI> *>(quant_ptr) = quantized_values;
1485+
if (wi_id == 0) {
1486+
*ds_ptr = sycl::half2(sycl::half(d), sycl::half(sum));
1487+
}
1488+
}
1489+
14371490
static void mul_mat_p021_f16_f32(
14381491
const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst,
14391492
const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y,
@@ -1718,23 +1771,30 @@ static void pool2d_nchw_kernel(
17181771
o_ptr[cur_oh * ow + cur_ow] = res;
17191772
}
17201773

1721-
static void quantize_row_q8_1_sycl(const float *x, void *vy, const int kx,
1722-
const int ky, const int kx_padded,
1723-
queue_ptr stream) {
1724-
const int block_num_x = (kx_padded + SYCL_QUANTIZE_BLOCK_SIZE - 1) / SYCL_QUANTIZE_BLOCK_SIZE;
1725-
const sycl::range<3> num_blocks(1, ky, block_num_x);
1726-
int constexpr QUANT_BLOCK_TILE = QK8_1 / WARP_SIZE;
1727-
static_assert(QK8_1 % WARP_SIZE == 0);
1728-
const sycl::range<3> block_size(1, 1, SYCL_QUANTIZE_BLOCK_SIZE / QUANT_BLOCK_TILE);
1729-
{
1730-
dpct::has_capability_or_fail(stream->get_device(),
1731-
{sycl::aspect::fp16});
1774+
static void quantize_row_q8_1_sycl(const float * x, void * vy, const int kx, const int ky, const int kx_padded,
1775+
bool reorder_q8_tensor, queue_ptr stream) {
1776+
if (reorder_q8_tensor) {
1777+
auto local_range = std::size_t(WARP_SIZE);
1778+
auto num_quant_blocks = ky * (kx / QK8_1);
1779+
auto global_range = num_quant_blocks * local_range;
1780+
stream->parallel_for(sycl::nd_range<1>({ global_range }, { local_range }),
1781+
[=](sycl::nd_item<1> it) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
1782+
quantize_and_reorder_q8_1<QK8_1 / WARP_SIZE>(x, vy, kx, kx_padded, it);
1783+
});
1784+
} else {
1785+
const int block_num_x = (kx_padded + SYCL_QUANTIZE_BLOCK_SIZE - 1) / SYCL_QUANTIZE_BLOCK_SIZE;
1786+
const sycl::range<3> num_blocks(1, ky, block_num_x);
1787+
int constexpr QUANT_BLOCK_TILE = QK8_1 / WARP_SIZE;
1788+
static_assert(QK8_1 % WARP_SIZE == 0);
1789+
const sycl::range<3> block_size(1, 1, SYCL_QUANTIZE_BLOCK_SIZE / QUANT_BLOCK_TILE);
1790+
{
1791+
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
17321792

1733-
stream->parallel_for(
1734-
sycl::nd_range<3>(num_blocks * block_size, block_size),
1735-
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
1736-
quantize_q8_1<QUANT_BLOCK_TILE>(x, vy, kx, kx_padded, item_ct1);
1737-
});
1793+
stream->parallel_for(sycl::nd_range<3>(num_blocks * block_size, block_size),
1794+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
1795+
quantize_q8_1<QUANT_BLOCK_TILE>(x, vy, kx, kx_padded, item_ct1);
1796+
});
1797+
}
17381798
}
17391799
}
17401800

@@ -2446,9 +2506,10 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
24462506
dev[i].src1_ddq = dev[i].src1_ddq_alloc.alloc(ctx.pool(i), nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs);
24472507

24482508
if (src1_on_device && src1_is_contiguous) {
2509+
bool reorder_q8_tensor = src0->extra && ((ggml_tensor_extra_gpu *)src0->extra)->optimized_feature.reorder;
24492510
scope_op_debug_print scope_dbg_print(__func__, "/quantize_row_q8_1_sycl", dst,
24502511
/*num_src=*/2, " : converting src1 to Q8_1");
2451-
quantize_row_q8_1_sycl(dev[i].src1_ddf, dev[i].src1_ddq, ne10, nrows1, src1_padded_col_size, stream);
2512+
quantize_row_q8_1_sycl(dev[i].src1_ddf, dev[i].src1_ddq, ne10, nrows1, src1_padded_col_size, reorder_q8_tensor, stream);
24522513
/*
24532514
DPCT1010:90: SYCL uses exceptions to report errors and does not
24542515
use the error codes. The call was replaced with 0. You need to
@@ -2554,7 +2615,7 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
25542615
if (convert_src1_to_q8_1 && !src1_is_contiguous) {
25552616
scope_op_debug_print scope_dbg_print(__func__, "/quantize_row_q8_1_sycl", dst,
25562617
/*num_src=*/2, " : converting src1 to Q8_1");
2557-
quantize_row_q8_1_sycl(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, src1_padded_col_size, stream);
2618+
quantize_row_q8_1_sycl(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, src1_padded_col_size, false, stream);
25582619
/*
25592620
DPCT1010:92: SYCL uses exceptions to report errors and does
25602621
not use the error codes. The call was replaced with 0. You

ggml/src/ggml-sycl/mmvq.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,6 @@ static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __r
2929
static_assert(blocks_per_subgroup > 0);
3030
static_assert(block_elements_per_subgroup > 0);
3131

32-
const block_q8_1 * y = (const block_q8_1 *) vy;
33-
3432
float partial_sum = 0.0f;
3533
for (int i = sg.get_local_linear_id() / block_elements_per_subgroup; i < blocks_per_row; i += blocks_per_subgroup) {
3634
const int ibx = row * blocks_per_row + i; // x block index
@@ -40,13 +38,15 @@ static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __r
4038

4139
// Y block index that aligns with ibx
4240
const int iby = i * block_type::block_to_q8_1_ratio();
41+
const int8_t* q8_1_quant_ptr = (const int8_t*)vy + iby * QK8_1;
42+
const sycl::half2* q8_1_ds_ptr = (const sycl::half2*)((const char*)vy + ncols + iby * sizeof(sycl::half2));
4343

4444
#pragma unroll
4545
for (int elem = 0; elem < block_elements_per_subgroup; elem += WARP_SIZE) {
4646
// x block quant index when casting the quants to int
4747
const int iqs = elem + block_traits::vdr_mmvq * (sg.get_local_linear_id() % block_elements_per_subgroup);
4848

49-
partial_sum += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, &y[iby], iqs, nblocks);
49+
partial_sum += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, q8_1_quant_ptr, q8_1_ds_ptr, iqs, nblocks);
5050
}
5151
}
5252

ggml/src/ggml-sycl/vecdotq.hpp

Lines changed: 38 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -285,21 +285,21 @@ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q4_0> {
285285
}
286286

287287
__dpct_inline__ float operator()(const void * __restrict__ vbq, const int ibx_offset, const int d_offset,
288-
const block_q8_1 * __restrict__ bq8_1, const int & iqs, int /* nblocks */) {
288+
const int8_t* q8_1_quant_ptr, const sycl::half2* q8_1_ds, const int & iqs, int /* nblocks */) {
289289
const uint8_t * bq4_0 = static_cast<const uint8_t *>(vbq) + ibx_offset;
290290
const ggml_half d = *(reinterpret_cast<const ggml_half *>(static_cast<const uint8_t *>(vbq) + d_offset));
291291
int v[q4_0_traits::vdr_mmvq];
292292
int u[2 * q4_0_traits::vdr_mmvq];
293293

294-
#pragma unroll
295294

295+
#pragma unroll
296296
for (size_t i = 0; i < q4_0_traits::vdr_mmvq; ++i) {
297297
v[i] = get_int_from_uint8(bq4_0, iqs + i);
298-
u[2 * i + 0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
299-
u[2 * i + 1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + q4_0_traits::qi);
298+
u[2 * i + 0] = get_int_from_int8_aligned(q8_1_quant_ptr, iqs + i);
299+
u[2 * i + 1] = get_int_from_int8_aligned(q8_1_quant_ptr, iqs + i + q4_0_traits::qi);
300300
}
301301

302-
return vec_dot_q4_0_q8_1_impl(v, u, d, bq8_1->ds);
302+
return vec_dot_q4_0_q8_1_impl(v, u, d, *q8_1_ds);
303303
};
304304
};
305305

@@ -347,7 +347,7 @@ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K> {
347347
using q4_k_traits = typename q4_k_block::traits;
348348

349349
float operator()(const void * __restrict__ vbq, const int ibx_offset, const int d_offset,
350-
const block_q8_1 * __restrict__ bq8_1, const int & iqs, int nblocks) {
350+
const int8_t* q8_1_quant_ptr, const sycl::half2* q8_1_ds, const int & iqs, int nblocks) {
351351
const int ib = ibx_offset / (QK_K / 2);
352352

353353
const uint8_t * base = static_cast<const uint8_t *>(vbq);
@@ -360,7 +360,38 @@ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K> {
360360
const int * q4 = (const int *) (qs + 16 * bq8_offset + 4 * ((iqs / 2) % 4));
361361
const uint16_t * scales = (const uint16_t *) scs;
362362

363-
return vec_dot_q4_K_q8_1_common(q4, scales, *dms, bq8_1, iqs);
363+
int v[2];
364+
int u[2 * QR4_K];
365+
float d8[QR4_K];
366+
367+
v[0] = q4[0];
368+
v[1] = q4[4];
369+
370+
uint16_t aux[2];
371+
const int j = (QR4_K * ((iqs / 2) / (QI8_1 / 2))) / 2;
372+
if (j < 2) {
373+
aux[0] = scales[j + 0] & 0x3f3f;
374+
aux[1] = scales[j + 2] & 0x3f3f;
375+
} else {
376+
aux[0] = ((scales[j + 2] >> 0) & 0x0f0f) | ((scales[j - 2] & 0xc0c0) >> 2);
377+
aux[1] = ((scales[j + 2] >> 4) & 0x0f0f) | ((scales[j - 0] & 0xc0c0) >> 2);
378+
}
379+
380+
const uint8_t * sc = (const uint8_t *) aux;
381+
const uint8_t * m = sc + 2;
382+
383+
for (int i = 0; i < QR4_K; ++i) {
384+
const int8_t* quant_base_ptr = q8_1_quant_ptr + (bq8_offset + i) * QK8_1;
385+
sycl::half2 ds_values = *(q8_1_ds + bq8_offset + i);
386+
387+
d8[i] = ds_values[0];
388+
389+
const int * q8 = (const int *) quant_base_ptr + ((iqs / 2) % 4);
390+
u[2 * i + 0] = q8[0];
391+
u[2 * i + 1] = q8[4];
392+
}
393+
394+
return vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, *dms, d8);
364395
}
365396
};
366397

0 commit comments

Comments
 (0)