Skip to content

Dequant improvements rebase #8255

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions ggml/src/ggml-sycl/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -351,4 +351,10 @@ static __dpct_inline__ float warp_reduce_max(float x,
return x;
}

// Helper for vec loading aligned data
template <typename Tp, int n>
inline sycl::vec<Tp, n> vec_aligned_load(const Tp* aligned_ptr) {
return *reinterpret_cast<const sycl::vec<Tp, n>*>(aligned_ptr);
}

#endif // GGML_SYCL_COMMON_HPP
7 changes: 5 additions & 2 deletions ggml/src/ggml-sycl/convert.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,12 +152,15 @@ static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int k,
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});

stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
stream->submit([&](sycl::handler &cgh) {
sycl::local_accessor<uint8_t, 1> scale_local_acc(sycl::range<1>(12), cgh);
cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
sycl::range<3>(1, 1, 32),
sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_q4_K(vx, y, item_ct1);
dequantize_block_q4_K(vx, y, scale_local_acc.get_pointer(), item_ct1);
});
});
}
}

Expand Down
30 changes: 19 additions & 11 deletions ggml/src/ggml-sycl/dequantize.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,8 @@ static void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restri
#if QK_K == 256
static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
if (j < 4) {
d = q[j] & 63; m = q[j + 4] & 63;
d = q[j] & 63;
m = q[j + 4] & 63;
} else {
d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
Expand All @@ -303,7 +304,7 @@ static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8

template<typename dst_t>
static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
const sycl::nd_item<3> &item_ct1) {
uint8_t* scales_local, const sycl::nd_item<3> &item_ct1) {
const block_q4_K * x = (const block_q4_K *) vx;

const int i = item_ct1.get_group(2);
Expand All @@ -318,19 +319,26 @@ static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restri

dst_t * y = yy + i*QK_K + 64*il + n*ir;

const float dall = x[i].dm[0];
const float dmin = x[i].dm[1];
const sycl::half2 dm = x[i].dm;
const float dall = dm[0];
const float dmin = dm[1];

const uint8_t * q = x[i].qs + 32*il + n*ir;
if (tid < 12)
scales_local[tid] = x[i].scales[tid];
item_ct1.barrier(sycl::access::fence_space::local_space);

uint8_t sc, m;
get_scale_min_k4(is + 0, x[i].scales, sc, m);
const float d1 = dall * sc; const float m1 = dmin * m;
get_scale_min_k4(is + 1, x[i].scales, sc, m);
const float d2 = dall * sc; const float m2 = dmin * m;
get_scale_min_k4(is + 0, scales_local, sc, m);
const float d1 = dall * sc;
const float m1 = dmin * m;
get_scale_min_k4(is + 1, scales_local, sc, m);
const float d2 = dall * sc;
const float m2 = dmin * m;

sycl::vec<uint8_t, n> q_vec = vec_aligned_load<uint8_t, n>(x[i].qs + 32*il + n*ir);
for (int l = 0; l < n; ++l) {
y[l + 0] = d1 * (q[l] & 0xF) - m1;
y[l +32] = d2 * (q[l] >> 4) - m2;
y[l + 0] = d1 * (q_vec[l] & 0xF) - m1;
y[l +32] = d2 * (q_vec[l] >> 4) - m2;
}
#else
const int tid = item_ct1.get_local_id(2);
Expand Down
Loading