Skip to content

Commit fadde67

Browse files
author
AidanBeltonS
authored
Dequant improvements rebase (#8255)
* Single load for half2 * Store scales in local mem * Vec load quantized values
1 parent a27152b commit fadde67

File tree

3 files changed

+30
-13
lines changed

3 files changed

+30
-13
lines changed

ggml/src/ggml-sycl/common.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,4 +351,10 @@ static __dpct_inline__ float warp_reduce_max(float x,
351351
return x;
352352
}
353353

354+
// Helper for vec loading aligned data
355+
template <typename Tp, int n>
356+
inline sycl::vec<Tp, n> vec_aligned_load(const Tp* aligned_ptr) {
357+
return *reinterpret_cast<const sycl::vec<Tp, n>*>(aligned_ptr);
358+
}
359+
354360
#endif // GGML_SYCL_COMMON_HPP

ggml/src/ggml-sycl/convert.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,12 +152,15 @@ static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int k,
152152
dpct::has_capability_or_fail(stream->get_device(),
153153
{sycl::aspect::fp16});
154154

155-
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
155+
stream->submit([&](sycl::handler &cgh) {
156+
sycl::local_accessor<uint8_t, 1> scale_local_acc(sycl::range<1>(12), cgh);
157+
cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
156158
sycl::range<3>(1, 1, 32),
157159
sycl::range<3>(1, 1, 32)),
158160
[=](sycl::nd_item<3> item_ct1) {
159-
dequantize_block_q4_K(vx, y, item_ct1);
161+
dequantize_block_q4_K(vx, y, scale_local_acc.get_pointer(), item_ct1);
160162
});
163+
});
161164
}
162165
}
163166

ggml/src/ggml-sycl/dequantize.hpp

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,8 @@ static void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restri
293293
#if QK_K == 256
294294
static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
295295
if (j < 4) {
296-
d = q[j] & 63; m = q[j + 4] & 63;
296+
d = q[j] & 63;
297+
m = q[j + 4] & 63;
297298
} else {
298299
d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
299300
m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
@@ -303,7 +304,7 @@ static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8
303304

304305
template<typename dst_t>
305306
static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
306-
const sycl::nd_item<3> &item_ct1) {
307+
uint8_t* scales_local, const sycl::nd_item<3> &item_ct1) {
307308
const block_q4_K * x = (const block_q4_K *) vx;
308309

309310
const int i = item_ct1.get_group(2);
@@ -318,19 +319,26 @@ static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restri
318319

319320
dst_t * y = yy + i*QK_K + 64*il + n*ir;
320321

321-
const float dall = x[i].dm[0];
322-
const float dmin = x[i].dm[1];
322+
const sycl::half2 dm = x[i].dm;
323+
const float dall = dm[0];
324+
const float dmin = dm[1];
323325

324-
const uint8_t * q = x[i].qs + 32*il + n*ir;
326+
if (tid < 12)
327+
scales_local[tid] = x[i].scales[tid];
328+
item_ct1.barrier(sycl::access::fence_space::local_space);
325329

326330
uint8_t sc, m;
327-
get_scale_min_k4(is + 0, x[i].scales, sc, m);
328-
const float d1 = dall * sc; const float m1 = dmin * m;
329-
get_scale_min_k4(is + 1, x[i].scales, sc, m);
330-
const float d2 = dall * sc; const float m2 = dmin * m;
331+
get_scale_min_k4(is + 0, scales_local, sc, m);
332+
const float d1 = dall * sc;
333+
const float m1 = dmin * m;
334+
get_scale_min_k4(is + 1, scales_local, sc, m);
335+
const float d2 = dall * sc;
336+
const float m2 = dmin * m;
337+
338+
sycl::vec<uint8_t, n> q_vec = vec_aligned_load<uint8_t, n>(x[i].qs + 32*il + n*ir);
331339
for (int l = 0; l < n; ++l) {
332-
y[l + 0] = d1 * (q[l] & 0xF) - m1;
333-
y[l +32] = d2 * (q[l] >> 4) - m2;
340+
y[l + 0] = d1 * (q_vec[l] & 0xF) - m1;
341+
y[l +32] = d2 * (q_vec[l] >> 4) - m2;
334342
}
335343
#else
336344
const int tid = item_ct1.get_local_id(2);

0 commit comments

Comments
 (0)