@@ -293,7 +293,8 @@ static void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restri
293
293
#if QK_K == 256
294
294
static inline void get_scale_min_k4 (int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
295
295
if (j < 4 ) {
296
- d = q[j] & 63 ; m = q[j + 4 ] & 63 ;
296
+ d = q[j] & 63 ;
297
+ m = q[j + 4 ] & 63 ;
297
298
} else {
298
299
d = (q[j+4 ] & 0xF ) | ((q[j-4 ] >> 6 ) << 4 );
299
300
m = (q[j+4 ] >> 4 ) | ((q[j-0 ] >> 6 ) << 4 );
@@ -303,7 +304,7 @@ static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8
303
304
304
305
template <typename dst_t >
305
306
static void dequantize_block_q4_K (const void * __restrict__ vx, dst_t * __restrict__ yy,
306
- const sycl::nd_item<3 > &item_ct1) {
307
+ uint8_t * scales_local, const sycl::nd_item<3 > &item_ct1) {
307
308
const block_q4_K * x = (const block_q4_K *) vx;
308
309
309
310
const int i = item_ct1.get_group (2 );
@@ -318,19 +319,26 @@ static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restri
318
319
319
320
dst_t * y = yy + i*QK_K + 64 *il + n*ir;
320
321
321
- const float dall = x[i].dm [0 ];
322
- const float dmin = x[i].dm [1 ];
322
+ const sycl::half2 dm = x[i].dm ;
323
+ const float dall = dm[0 ];
324
+ const float dmin = dm[1 ];
323
325
324
- const uint8_t * q = x[i].qs + 32 *il + n*ir;
326
+ if (tid < 12 )
327
+ scales_local[tid] = x[i].scales [tid];
328
+ item_ct1.barrier (sycl::access::fence_space::local_space);
325
329
326
330
uint8_t sc, m;
327
- get_scale_min_k4 (is + 0 , x[i].scales , sc, m);
328
- const float d1 = dall * sc; const float m1 = dmin * m;
329
- get_scale_min_k4 (is + 1 , x[i].scales , sc, m);
330
- const float d2 = dall * sc; const float m2 = dmin * m;
331
+ get_scale_min_k4 (is + 0 , scales_local, sc, m);
332
+ const float d1 = dall * sc;
333
+ const float m1 = dmin * m;
334
+ get_scale_min_k4 (is + 1 , scales_local, sc, m);
335
+ const float d2 = dall * sc;
336
+ const float m2 = dmin * m;
337
+
338
+ sycl::vec<uint8_t , n> q_vec = vec_aligned_load<uint8_t , n>(x[i].qs + 32 *il + n*ir);
331
339
for (int l = 0 ; l < n; ++l) {
332
- y[l + 0 ] = d1 * (q [l] & 0xF ) - m1;
333
- y[l +32 ] = d2 * (q [l] >> 4 ) - m2;
340
+ y[l + 0 ] = d1 * (q_vec [l] & 0xF ) - m1;
341
+ y[l +32 ] = d2 * (q_vec [l] >> 4 ) - m2;
334
342
}
335
343
#else
336
344
const int tid = item_ct1.get_local_id (2 );
0 commit comments