Commit b460d16

sycl: Add reorder to Q6_K mmvq implementation (#13885)
* Add Reorder to Q6_K mmvq implementation
* Address PR comments: clean up comments
* Remove unused parameter after refactoring q4_k
* Add inline to function and remove unnecessary reference to int

Signed-off-by: nscipione <[email protected]>
1 parent 91a8ee6 commit b460d16

File tree

6 files changed: +244 additions, −30 deletions

ggml/src/ggml-sycl/convert.cpp

Lines changed: 21 additions & 2 deletions
@@ -265,6 +265,17 @@ static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int64_t k,
 #endif
 }
 
+template <typename dst_t>
+static void dequantize_row_q6_K_sycl_reorder(const void * vx, dst_t * y, const int64_t k, dpct::queue_ptr stream) {
+    const int64_t nb = k / QK_K;
+
+    dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
+
+    stream->parallel_for(
+        sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)),
+        [=](sycl::nd_item<3> item_ct1) { dequantize_block_q6_K_reorder(vx, y, item_ct1, nb); });
+}
+
 template <typename dst_t>
 static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int64_t k,
                                       dpct::queue_ptr stream) {
@@ -530,7 +541,11 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) {
         case GGML_TYPE_Q5_K:
             return dequantize_row_q5_K_sycl;
         case GGML_TYPE_Q6_K:
-            return dequantize_row_q6_K_sycl;
+            if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
+                return dequantize_row_q6_K_sycl_reorder;
+            } else {
+                return dequantize_row_q6_K_sycl;
+            }
         case GGML_TYPE_IQ1_S:
             return dequantize_row_iq1_s_sycl;
         case GGML_TYPE_IQ1_M:
@@ -587,7 +602,11 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) {
         case GGML_TYPE_Q5_K:
             return dequantize_row_q5_K_sycl;
         case GGML_TYPE_Q6_K:
-            return dequantize_row_q6_K_sycl;
+            if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
+                return dequantize_row_q6_K_sycl_reorder;
+            } else {
+                return dequantize_row_q6_K_sycl;
+            }
         case GGML_TYPE_IQ1_S:
             return dequantize_row_iq1_s_sycl;
         case GGML_TYPE_IQ1_M:
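
To make the launch geometry above concrete, here is a minimal host-side sketch, not part of the commit, of how dequantize_row_q6_K_sycl_reorder covers an input of k values, assuming QK_K == 256 (the Q6_K super-block size): one work-group of 64 work-items per super-block, each work-item producing four outputs.

// Standalone sketch; k is a hypothetical element count and must be a multiple of QK_K.
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t QK_K = 256;
    const int64_t k    = 4096;                    // hypothetical element count
    const int64_t nb   = k / QK_K;                // number of super-blocks -> number of work-groups
    const int64_t work_items       = nb * 64;     // nd_range global size along the last dimension
    const int64_t outputs_per_item = QK_K / 64;   // each work-item writes 4 dequantized values
    std::printf("work-groups=%lld work-items=%lld outputs/work-item=%lld\n",
                (long long) nb, (long long) work_items, (long long) outputs_per_item);
    return 0;
}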

ggml/src/ggml-sycl/dequantize.hpp

Lines changed: 32 additions & 0 deletions
@@ -538,6 +538,38 @@ static void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restri
 #endif
 }
 
+template <typename dst_t>
+static void dequantize_block_q6_K_reorder(const void * __restrict__ vx, dst_t * __restrict__ yy,
+                                          const sycl::nd_item<3> & item_ct1, int64_t n_blocks) {
+    const int64_t ib = item_ct1.get_group(2);
+
+    const int64_t tid = item_ct1.get_local_id(2);
+    const int64_t ip  = tid / 32;       // ip is 0 or 1
+    const int64_t il  = tid - 32 * ip;  // 0...32
+    const int64_t is  = 8 * ip + il / 16;
+
+    const uint8_t *   base_ptr           = static_cast<const uint8_t *>(vx);
+    const auto        ql_offset          = ib * (QK_K / 2);
+    const auto        qh_offset          = (QK_K / 2) * n_blocks + (QK_K / 4) * ib;
+    const auto        base_scales_offset = (QK_K / 2) * n_blocks + (QK_K / 4) * n_blocks + (QK_K / 16) * ib;
+    const auto        base_d_offset      = ((QK_K / 2) + (QK_K / 4) + (QK_K / 16)) * n_blocks;
+    const uint8_t *   ql_ptr             = base_ptr + ql_offset;
+    const uint8_t *   qh_ptr             = base_ptr + qh_offset;
+    const uint8_t *   scales_ptr         = base_ptr + base_scales_offset;
+    const ggml_half * d                  = (const ggml_half *) (base_ptr + base_d_offset) + ib;
+
+    dst_t * y = yy + ib * QK_K + 128 * ip + il;
+
+    const uint8_t * ql = ql_ptr + 64 * ip + il;
+    const uint8_t   qh = *(qh_ptr + 32 * ip + il);
+    const int8_t *  sc = reinterpret_cast<const int8_t *>(scales_ptr + is);
+
+    y[0]  = *d * sc[0] * ((int8_t) ((ql[0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32);
+    y[32] = *d * sc[2] * ((int8_t) ((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32);
+    y[64] = *d * sc[4] * ((int8_t) ((ql[0] >> 4) | (((qh >> 4) & 3) << 4)) - 32);
+    y[96] = *d * sc[6] * ((int8_t) ((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32);
+}
+
 template<typename dst_t>
 static void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy,
                                      const sycl::nd_item<3> &item_ct1,
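
As a reading aid for the kernel above, here is a small CPU-only sketch, not part of the commit, of the same per-thread index math: each of the 64 work-items picks a half ip, a lane il and a scale group is, then writes four outputs at y_base, y_base+32, y_base+64 and y_base+96, so one work-group covers all 256 values of a super-block.

// Standalone sketch mirroring the index math of dequantize_block_q6_K_reorder.
#include <cstdio>

int main() {
    for (int tid = 0; tid < 64; ++tid) {
        const int ip     = tid / 32;          // which 128-value half of the super-block (0 or 1)
        const int il     = tid - 32 * ip;     // lane within the half, 0..31
        const int is     = 8 * ip + il / 16;  // scale group; the kernel reads sc[0], sc[2], sc[4], sc[6] from here
        const int y_base = 128 * ip + il;     // first of the four outputs written by this work-item
        std::printf("tid=%2d -> y[%3d], y[%3d], y[%3d], y[%3d] (is=%2d)\n",
                    tid, y_base, y_base + 32, y_base + 64, y_base + 96, is);
    }
    return 0;
}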

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 51 additions & 1 deletion
@@ -354,7 +354,8 @@ ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer,
         assert(tensor->view_src->buffer->buft == buffer->buft);
         return GGML_STATUS_SUCCESS;
     }
-    if ((tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q4_K) && !g_ggml_sycl_disable_optimize) {
+    if ((tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q4_K || tensor->type == GGML_TYPE_Q6_K) &&
+        !g_ggml_sycl_disable_optimize) {
         ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
         tensor->extra = extra;
         ctx->tensor_extras.push_back(extra); //used to release it when destroy ctx.
@@ -2989,6 +2990,7 @@ inline bool ggml_sycl_supports_reorder_mul_mat_sycl(enum ggml_type type) {
         case GGML_TYPE_Q4_0:
             return true;
         case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q6_K:
            return !g_ggml_sycl_prioritize_dmmv;
        default:
            return false;
@@ -3008,6 +3010,7 @@ inline bool ggml_sycl_supports_reorder_mmvq(enum ggml_type type) {
    switch (type) {
        case GGML_TYPE_Q4_0:
        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q6_K:
            return true;
        default:
            return false;
@@ -3092,6 +3095,50 @@ static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, d
     sycl::free(tmp_buf, *stream);
 }
 
+static void reorder_qw_q6_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {
+    GGML_ASSERT(size % sizeof(block_q6_K) == 0);
+    GGML_ASSERT(offset % sizeof(block_q6_K) == 0);
+
+    const int nblocks = size / sizeof(block_q6_K);
+
+    auto * tmp_buf = sycl::malloc_shared<uint8_t>(size, *stream);
+    SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size).wait()));
+
+    auto * ql_ptr       = data_device;
+    auto * qh_ptr       = ql_ptr + (QK_K / 2) * nblocks;
+    auto * scales_ptr   = qh_ptr + (QK_K / 4) * nblocks;
+    sycl::half * dm_ptr = (sycl::half *) (scales_ptr + (QK_K / 16) * nblocks);
+
+    stream
+        ->parallel_for(nblocks,
+                       [=](auto i) {
+                           const block_q6_K * x  = (const block_q6_K *) tmp_buf;
+                           const int          ib = i;
+
+                           const uint8_t * ql              = x[ib].ql;
+                           const uint8_t * qh              = x[ib].qh;
+                           uint8_t *       base_ql_ptr     = ql_ptr + (QK_K / 2) * ib;
+                           uint8_t *       base_qh_ptr     = qh_ptr + (QK_K / 4) * ib;
+                           uint8_t *       base_scales_ptr = scales_ptr + (QK_K / 16) * ib;
+
+                           for (int j = 0; j < QK_K / 2; ++j) {
+                               base_ql_ptr[j] = ql[j];
+                           }
+                           for (int j = 0; j < QK_K / 4; ++j) {
+                               base_qh_ptr[j] = qh[j];
+                           }
+
+                           for (int j = 0; j < QK_K / 16; ++j) {
+                               base_scales_ptr[j] = x[ib].scales[j];
+                           }
+
+                           dm_ptr[ib] = x[ib].d;
+                       })
+        .wait_and_throw();
+
+    sycl::free(tmp_buf, *stream);
+}
+
 static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) {
     uint8_t * data_device = (uint8_t *) src0->data;
     size_t ncols = src0->ne[0];
@@ -3105,6 +3152,9 @@ static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) {
         case GGML_TYPE_Q4_K:
             reorder_qw_q4_k(data_device, size, 0, stream);
             break;
+        case GGML_TYPE_Q6_K:
+            reorder_qw_q6_k(data_device, size, 0, stream);
+            break;
         default:
             GGML_ABORT("reorder_qw() called with unsupported type");
             break;
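
For orientation, a minimal sketch, not from the commit, of the buffer layout that reorder_qw_q6_k produces, assuming QK_K == 256 and the usual block_q6_K fields (QK_K/2 bytes of low bits, QK_K/4 bytes of high bits, QK_K/16 scale bytes and one ggml_half d per super-block): the interleaved blocks become four contiguous regions with the same total size.

// Standalone sketch of the reordered Q6_K region offsets for nblocks super-blocks.
#include <cstddef>
#include <cstdio>

int main() {
    const size_t QK_K    = 256;   // assumed super-block size
    const size_t nblocks = 8;     // hypothetical block count
    const size_t ql_off     = 0;                                   // low 4-bit quants, QK_K/2 bytes per block
    const size_t qh_off     = ql_off + (QK_K / 2) * nblocks;       // high 2-bit quants, QK_K/4 bytes per block
    const size_t scales_off = qh_off + (QK_K / 4) * nblocks;       // int8 scales, QK_K/16 bytes per block
    const size_t d_off      = scales_off + (QK_K / 16) * nblocks;  // one half-precision d per block
    const size_t total      = d_off + 2 * nblocks;                 // ggml_half assumed to be 2 bytes
    std::printf("ql@0 qh@%zu scales@%zu d@%zu total=%zu bytes (%zu per block)\n",
                qh_off, scales_off, d_off, total, total / nblocks);
    return 0;
}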

ggml/src/ggml-sycl/mmvq.cpp

Lines changed: 30 additions & 6 deletions
@@ -31,11 +31,10 @@ static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __r
 
     float partial_sum = 0.0f;
     for (int i = sg.get_local_linear_id() / block_elements_per_subgroup; i < blocks_per_row; i += blocks_per_subgroup) {
-        const int ibx = row * blocks_per_row + i; // x block index
-        // TODO: Generalize offsets, right now only works for quantizations that don't split high and low bits
-        const int bx_offset = block_type::get_block_offset(ibx);
-        const int d_offset = block_type::get_d_offset(nrows, ncols, ibx);
+        const int ibx = row * blocks_per_row + i; // x block index
 
+        const auto bx_offset = block_type::get_block_offset(ibx, nblocks);
+        const auto d_offset = block_type::get_d_offset(nrows, ncols, ibx);
         // Y block index that aligns with ibx
         const int iby = i * block_type::block_to_q8_1_ratio();
         const int8_t* q8_1_quant_ptr = (const int8_t*)vy + iby * QK8_1;
@@ -46,7 +45,7 @@ static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __r
             // x block quant index when casting the quants to int
             const int iqs = elem + block_traits::vdr_mmvq * (sg.get_local_linear_id() % block_elements_per_subgroup);
 
-            partial_sum += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, q8_1_quant_ptr, q8_1_ds_ptr, iqs, nblocks);
+            partial_sum += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, q8_1_quant_ptr, q8_1_ds_ptr, iqs);
         }
     }
 
@@ -785,6 +784,24 @@ static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy,
     }
 }
 
+static void reorder_mul_mat_vec_q6_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols,
+                                               const int nrows, dpct::queue_ptr stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+    const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);
+    constexpr size_t num_subgroups = 16;
+    GGML_ASSERT(block_num_y % num_subgroups == 0);
+
+    const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
+    const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
+
+    stream->submit([&](sycl::handler & cgh) {
+        cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
+                         [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                             mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q6_K>>(vx, vy, dst, ncols, nrows,
+                                                                                           nd_item);
+                         });
+    });
+}
 static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy,
                                        float *dst, const int ncols,
                                        const int nrows,
@@ -1070,7 +1087,14 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens
             mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
             break;
         case GGML_TYPE_Q6_K:
-            mul_mat_vec_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+            if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
+                ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
+                GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q6_k_q8_1_sycl\n");
+                reorder_mul_mat_vec_q6_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+            } else {
+                GGML_SYCL_DEBUG("Calling mul_mat_vec_q6_k_q8_1_sycl\n");
+                mul_mat_vec_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+            }
             break;
         case GGML_TYPE_IQ1_S:
             mul_mat_vec_iq1_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
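
A quick sketch, not from the commit, of the launch-size arithmetic used by reorder_mul_mat_vec_q6_k_q8_1_sycl. GGML_SYCL_MMV_Y and WARP_SIZE are backend build constants, so the values below are placeholders chosen only to make the numbers concrete.

// Standalone sketch of the reordered mmvq launch geometry; the constants are placeholders.
#include <cstdio>

int main() {
    const int GGML_SYCL_MMV_Y = 1;   // placeholder for the backend constant
    const int WARP_SIZE       = 32;  // placeholder for the backend constant
    const int num_subgroups   = 16;  // as in the kernel above

    const int nrows       = 4096;                                             // hypothetical row count
    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;  // ceil_div(nrows, GGML_SYCL_MMV_Y)

    // global range: (1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE)
    // local  range: (1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE)
    const int groups_z        = (block_num_y * WARP_SIZE) / (num_subgroups * WARP_SIZE);
    const int items_per_group = GGML_SYCL_MMV_Y * num_subgroups * WARP_SIZE;
    std::printf("block_num_y=%d work-groups=%d work-items/group=%d\n", block_num_y, groups_z, items_per_group);
    return 0;
}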

ggml/src/ggml-sycl/quants.hpp

Lines changed: 38 additions & 10 deletions
@@ -14,12 +14,13 @@
 #ifndef GGML_SYCL_QUANTS_HPP
 #define GGML_SYCL_QUANTS_HPP
 
+#include <utility>
+
 #include "ggml-common.h"
 #include "ggml.h"
 
 namespace ggml_sycl_reordered {
 
-
 // The reordered block moves quants (qs) and scales(d) to two
 // uniform regions of memory that is contiguous in the same tensor.
 // What this means is that instead of having:
@@ -32,7 +33,6 @@ namespace ggml_sycl_reordered {
 
 template <ggml_type type> struct block_q_t;
 
-
 // qk number of weights / quants in a block
 // qr number of weights in a byte (described as 'before dequantization')
 // for quantization types that has low and high bits split, qr is calculated with
@@ -47,10 +47,12 @@ template <> struct block_q_t<GGML_TYPE_Q4_0> {
         static constexpr uint32_t vdr_mmvq = 2;
     };
 
-    static constexpr int get_block_offset(const int block_index) { return block_index * (traits::qk / traits::qr); }
+    static constexpr std::pair<int, int> get_block_offset(const int block_index, const int /* nblocks */) {
+        return { block_index * (traits::qk / traits::qr), 0 };
+    }
 
-    static constexpr int get_d_offset(int nrows, int ncols, const int block_index) {
-        return (ncols / traits::qr * nrows) + block_index * sizeof(ggml_half);
+    static constexpr std::pair<int, int> get_d_offset(int nrows, int ncols, const int block_index) {
+        return { (ncols / traits::qr * nrows) + block_index * sizeof(ggml_half), 0 };
     }
 
     static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }
@@ -64,20 +66,46 @@ template <> struct block_q_t<GGML_TYPE_Q4_K> {
         static constexpr uint32_t vdr_mmvq = 2;
     };
 
-    static constexpr int get_block_offset(const int block_index) { return block_index * (traits::qk / traits::qr); }
+    static constexpr std::pair<int, int> get_block_offset(const int block_index, const int /* nblocks */) {
+        return { block_index * (traits::qk / traits::qr), 0 };
+    }
 
-    static constexpr int get_d_offset(int nrows, int ncols, const int block_index) {
+    static constexpr std::pair<int, int> get_d_offset(int nrows, int ncols, const int block_index) {
         auto nblocks = (nrows * (ncols / traits::qk));
-        return (nblocks * QK_K / 2) + (nblocks * K_SCALE_SIZE) + (block_index * sizeof(ggml_half2));
+        return { nblocks * (QK_K / 2),
+                 (nblocks * QK_K / 2) + (nblocks * K_SCALE_SIZE) + (block_index * sizeof(ggml_half2)) };
     }
 
     static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }
 
     constexpr size_t get_total_qs_bytes(int nblocks) { return nblocks * QK_K / 2; }
-
-    constexpr size_t get_dm_offset(int nblocks) { return get_total_qs_bytes(nblocks) + nblocks * K_SCALE_SIZE; }
 };
 
+template <> struct block_q_t<GGML_TYPE_Q6_K> {
+    struct traits {
+        static constexpr uint32_t qk       = QK_K;
+        static constexpr uint32_t qi       = QI6_K;
+        static constexpr uint32_t qr       = QR6_K;
+        static constexpr uint32_t vdr_mmvq = 1;
+    };
+
+    static constexpr std::pair<int, int> get_block_offset(const int block_index, const int n_blocks) {
+        auto low_bits_index = block_index * (traits::qk / traits::qr);
+        // the index of high bits it's after all low bits
+        auto high_bits_index = n_blocks * (QK_K / 2) + (block_index * (QK_K / 4));
+        return { low_bits_index, high_bits_index };
+    }
+
+    static constexpr std::pair<int, int> get_d_offset(int nrows, int ncols, const int block_index) {
+        auto nblocks        = (nrows * (ncols / traits::qk));
+        auto total_qs_bytes = nblocks * (QK_K / 2) + nblocks * (QK_K / 4);
+        auto block_scales   = total_qs_bytes + block_index * (QK_K / 16);
+        auto sb_scale       = total_qs_bytes + nblocks * (QK_K / 16);
+        return { block_scales, sb_scale };
+    }
+
+    static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }
+};
 } // namespace ggml_sycl_reordered
 
 #endif // GGML_SYCL_QUANTS_HPP
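
To see how the block_q_t<GGML_TYPE_Q6_K> helpers above map onto the regions written by reorder_qw_q6_k, here is a standalone sketch, not from the commit, that reproduces the two offset formulas as free functions, assuming QK_K == 256 and QR6_K == 2.

// Standalone sketch of the Q6_K offset formulas from quants.hpp (constants assumed, names hypothetical).
#include <cstdio>
#include <utility>

constexpr int QK_K_  = 256;  // assumed super-block size
constexpr int QR6_K_ = 2;    // assumed Q6_K qr (two low-bit quants per byte)

constexpr std::pair<int, int> q6k_block_offset(int block_index, int n_blocks) {
    const int low_bits_index  = block_index * (QK_K_ / QR6_K_);                  // into the ql region
    const int high_bits_index = n_blocks * (QK_K_ / 2) + block_index * (QK_K_ / 4);  // into the qh region
    return { low_bits_index, high_bits_index };
}

constexpr std::pair<int, int> q6k_d_offset(int nrows, int ncols, int block_index) {
    const int nblocks        = nrows * (ncols / QK_K_);
    const int total_qs_bytes = nblocks * (QK_K_ / 2) + nblocks * (QK_K_ / 4);    // end of ql + qh regions
    const int block_scales   = total_qs_bytes + block_index * (QK_K_ / 16);      // this block's scales
    const int sb_scale       = total_qs_bytes + nblocks * (QK_K_ / 16);          // start of the d region
    return { block_scales, sb_scale };
}

int main() {
    const int nrows = 2, ncols = 512, block_index = 3;
    const int nblocks   = nrows * (ncols / QK_K_);
    const auto [ql, qh] = q6k_block_offset(block_index, nblocks);
    const auto [sc, d]  = q6k_d_offset(nrows, ncols, block_index);
    std::printf("block %d: ql=%d qh=%d scales=%d d_region=%d\n", block_index, ql, qh, sc, d);
    return 0;
}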

0 commit comments
