@@ -3,5 +3,7 @@
 #include "dequantize.hpp"
 #include "presets.hpp"
 
+int constexpr QK_WARP_SIZE = 32;
+
 static void convert_f16(const void * vx, const int ib, const int iqs, dfloat2 & v){
     const sycl::half *x = (const sycl::half *)vx;
@@ -227,7 +229,7 @@ static void dequantize_mul_mat_vec_q2_k(const void *__restrict__ vx,
 
     // sum up partial sums and write back result
 #pragma unroll
-    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+    for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
         tmp +=
             dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
     }
@@ -346,7 +348,7 @@ static void dequantize_mul_mat_vec_q3_k(const void *__restrict__ vx,
 
     // sum up partial sums and write back result
 #pragma unroll
-    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+    for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
         tmp +=
             dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
     }
@@ -499,7 +501,7 @@ static void dequantize_mul_mat_vec_q4_k(const void *__restrict__ vx,
 
     // sum up partial sums and write back result
 #pragma unroll
-    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+    for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
         tmp +=
             dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
     }
@@ -633,7 +635,7 @@ static void dequantize_mul_mat_vec_q5_k(const void *__restrict__ vx,
 
     // sum up partial sums and write back result
 #pragma unroll
-    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+    for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
         tmp +=
             dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
     }
@@ -748,7 +750,7 @@ static void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const floa
 
     // sum up partial sums and write back result
 #pragma unroll
-    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+    for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
         tmp +=
             dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
     }
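
Not part of the commit — a minimal stand-alone sketch of the XOR-butterfly sub-group reduction that all five hunks above modify. It assumes a device that supports 32-wide sub-groups and uses the standard SYCL 2020 `sycl::permute_group_by_xor` in place of the `dpct::permute_sub_group_by_xor` helper; the toy kernel and expected value are illustrative only:

```cpp
#include <sycl/sycl.hpp>

constexpr int QK_WARP_SIZE = 32; // same value the patch hard-codes

int main() {
    sycl::queue q;
    float *out = sycl::malloc_shared<float>(1, q);

    q.parallel_for(
         sycl::nd_range<1>(QK_WARP_SIZE, QK_WARP_SIZE),
         [=](sycl::nd_item<1> it) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
             auto sg = it.get_sub_group();
             float tmp = float(it.get_local_linear_id()); // per-lane partial sum

             // XOR butterfly: after log2(QK_WARP_SIZE) steps every lane
             // holds the sum over all 32 lanes.
             for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
                 tmp += sycl::permute_group_by_xor(sg, tmp, mask);
             }

             if (sg.leader()) {
                 *out = tmp; // 0 + 1 + ... + 31 = 496
             }
         })
        .wait();

    sycl::free(out, q);
}
```

This also shows why the loop bound and the `reqd_sub_group_size` attribute must agree: if the compiled sub-group were narrower than the constant used for the mask, the shuffle would read lanes that do not exist — the mismatch this patch avoids by pinning both to `QK_WARP_SIZE`.
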
@@ -873,10 +875,10 @@ static void dequantize_mul_mat_vec_q2_K_sycl(const void *vx, const float *y,
     const int ny = 2; // very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2
     const int block_num_y = (nrows + ny - 1) / ny;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, ny, WARP_SIZE);
+    const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
     stream->parallel_for(
         sycl::nd_range<3>(block_nums * block_dims, block_dims),
-        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
             dequantize_mul_mat_vec_q2_k(vx, y, dst, ncols, nrows, item_ct1);
         });
 }
@@ -889,10 +891,10 @@ static void dequantize_mul_mat_vec_q3_K_sycl(const void *vx, const float *y,
     const int ny = 2 / K_QUANTS_PER_ITERATION;
     const int block_num_y = (nrows + ny - 1) / ny;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, ny, WARP_SIZE);
+    const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
     stream->parallel_for(
         sycl::nd_range<3>(block_nums * block_dims, block_dims),
-        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
             dequantize_mul_mat_vec_q3_k(vx, y, dst, ncols, nrows, item_ct1);
         });
 }
@@ -905,10 +907,10 @@ static void dequantize_mul_mat_vec_q4_K_sycl(const void *vx, const float *y,
     const int ny = 2 / K_QUANTS_PER_ITERATION;
     const int block_num_y = (nrows + ny - 1) / ny;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, ny, WARP_SIZE);
+    const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
     stream->parallel_for(
         sycl::nd_range<3>(block_nums * block_dims, block_dims),
-        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
             dequantize_mul_mat_vec_q4_k(vx, y, dst, ncols, nrows, item_ct1);
         });
 }
@@ -918,10 +920,10 @@ static void dequantize_mul_mat_vec_q5_K_sycl(const void *vx, const float *y,
                                              const int nrows,
                                              dpct::queue_ptr stream) {
     GGML_ASSERT(ncols % QK_K == 0);
-    const sycl::range<3> block_dims(1, 1, WARP_SIZE);
+    const sycl::range<3> block_dims(1, 1, QK_WARP_SIZE);
     stream->parallel_for(
         sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims),
-        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
             dequantize_mul_mat_vec_q5_k(vx, y, dst, ncols, item_ct1);
         });
 }
@@ -934,10 +936,10 @@ static void dequantize_mul_mat_vec_q6_K_sycl(const void *vx, const float *y,
     const int ny = 2 / K_QUANTS_PER_ITERATION;
     const int block_num_y = (nrows + ny - 1) / ny;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, ny, WARP_SIZE);
+    const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
    stream->parallel_for(
         sycl::nd_range<3>(block_nums * block_dims, block_dims),
-        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
             dequantize_mul_mat_vec_q6_k(vx, y, dst, ncols, nrows, item_ct1);
         });
 }
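
The wrapper hunks all share one launch shape: `ny` rows per work-group, one `QK_WARP_SIZE`-wide sub-group per row. A hedged sketch of that geometry, with the tensor arguments and kernel body elided — `launch_rows` is a hypothetical name, not a function from this file:

```cpp
#include <sycl/sycl.hpp>

constexpr int QK_WARP_SIZE = 32;

// Hypothetical helper mirroring the *_sycl wrappers above: ny rows per
// work-group, QK_WARP_SIZE work-items (one sub-group) per row.
static void launch_rows(sycl::queue *stream, const int nrows, const int ny) {
    const int block_num_y = (nrows + ny - 1) / ny;        // ceil(nrows / ny)
    const sycl::range<3> block_nums(1, 1, block_num_y);   // work-group grid
    const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE); // items per group

    stream->parallel_for(
        sycl::nd_range<3>(block_nums * block_dims, block_dims),
        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
            // Row owned by this sub-group; the real kernels derive it the
            // same way before looping over the row's quantized blocks.
            const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1)
                          + item_ct1.get_local_id(1);
            (void) row;
        });
}
```

Replacing `WARP_SIZE` with the new constant in both `block_dims` and the `reqd_sub_group_size` attribute keeps the launch width and the compiled sub-group width consistent on devices whose default `WARP_SIZE` differs from 32.
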