Skip to content

Commit b9111bd

Browse files
Update ggml_sycl_op_mul_mat_vec_q (#5502)
* Update ggml_sycl_op_mul_mat_vec_q * Apply suggestions from code review Co-authored-by: Abhilash Majumder <[email protected]> * revert suggestion on macro * fix bug * Add quant type GGML_TYPE_IQ1_S to unsupported * fix format --------- Co-authored-by: Abhilash Majumder <[email protected]>
1 parent 633782b commit b9111bd

File tree

1 file changed

+69
-189
lines changed

1 file changed

+69
-189
lines changed

ggml-sycl.cpp

Lines changed: 69 additions & 189 deletions
Original file line numberDiff line numberDiff line change
@@ -9188,174 +9188,22 @@ static void convert_mul_mat_vec_f16_sycl(const void *vx, const dfloat *y,
91889188
}
91899189
}
91909190

9191-
static void mul_mat_vec_q4_0_q8_1_sycl(const void *vx, const void *vy,
9192-
float *dst, const int ncols,
9193-
const int nrows,
9194-
dpct::queue_ptr stream) {
9195-
GGML_ASSERT(ncols % QK4_0 == 0);
9196-
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
9197-
const sycl::range<3> block_nums(1, 1, block_num_y);
9198-
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
9199-
stream->parallel_for(
9200-
sycl::nd_range<3>(block_nums * block_dims, block_dims),
9201-
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
9202-
mul_mat_vec_q<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ,
9203-
vec_dot_q4_0_q8_1>(vx, vy, dst, ncols, nrows,
9204-
item_ct1);
9205-
});
9206-
}
9207-
9208-
static void mul_mat_vec_q4_1_q8_1_sycl(const void *vx, const void *vy,
9209-
float *dst, const int ncols,
9210-
const int nrows,
9211-
dpct::queue_ptr stream) {
9212-
GGML_ASSERT(ncols % QK4_1 == 0);
9213-
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
9214-
const sycl::range<3> block_nums(1, 1, block_num_y);
9215-
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
9216-
stream->parallel_for(
9217-
sycl::nd_range<3>(block_nums * block_dims, block_dims),
9218-
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
9219-
mul_mat_vec_q<QK4_0, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ,
9220-
vec_dot_q4_1_q8_1>(vx, vy, dst, ncols, nrows,
9221-
item_ct1);
9222-
});
9223-
}
9224-
9225-
static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy,
9226-
float *dst, const int ncols,
9227-
const int nrows,
9228-
dpct::queue_ptr stream) {
9229-
GGML_ASSERT(ncols % QK5_0 == 0);
9230-
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
9231-
const sycl::range<3> block_nums(1, 1, block_num_y);
9232-
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
9233-
stream->parallel_for(
9234-
sycl::nd_range<3>(block_nums * block_dims, block_dims),
9235-
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
9236-
mul_mat_vec_q<QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ,
9237-
vec_dot_q5_0_q8_1>(vx, vy, dst, ncols, nrows,
9238-
item_ct1);
9239-
});
9240-
}
9241-
9242-
static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy,
9243-
float *dst, const int ncols,
9244-
const int nrows,
9245-
dpct::queue_ptr stream) {
9246-
GGML_ASSERT(ncols % QK5_1 == 0);
9247-
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
9248-
const sycl::range<3> block_nums(1, 1, block_num_y);
9249-
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
9250-
stream->parallel_for(
9251-
sycl::nd_range<3>(block_nums * block_dims, block_dims),
9252-
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
9253-
mul_mat_vec_q<QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ,
9254-
vec_dot_q5_1_q8_1>(vx, vy, dst, ncols, nrows,
9255-
item_ct1);
9256-
});
9257-
}
9258-
9259-
static void mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy,
9260-
float *dst, const int ncols,
9261-
const int nrows,
9262-
dpct::queue_ptr stream) {
9263-
GGML_ASSERT(ncols % QK8_0 == 0);
9264-
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
9265-
const sycl::range<3> block_nums(1, 1, block_num_y);
9266-
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
9267-
stream->parallel_for(
9268-
sycl::nd_range<3>(block_nums * block_dims, block_dims),
9269-
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
9270-
mul_mat_vec_q<QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ,
9271-
vec_dot_q8_0_q8_1>(vx, vy, dst, ncols, nrows,
9272-
item_ct1);
9273-
});
9274-
}
9275-
9276-
static void mul_mat_vec_q2_K_q8_1_sycl(const void *vx, const void *vy,
9277-
float *dst, const int ncols,
9278-
const int nrows,
9279-
dpct::queue_ptr stream) {
9280-
GGML_ASSERT(ncols % QK_K == 0);
9281-
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
9282-
const sycl::range<3> block_nums(1, 1, block_num_y);
9283-
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
9284-
stream->parallel_for(
9285-
sycl::nd_range<3>(block_nums * block_dims, block_dims),
9286-
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
9287-
mul_mat_vec_q<QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ,
9288-
vec_dot_q2_K_q8_1>(vx, vy, dst, ncols, nrows,
9289-
item_ct1);
9290-
});
9291-
}
9292-
9293-
static void mul_mat_vec_q3_K_q8_1_sycl(const void *vx, const void *vy,
9294-
float *dst, const int ncols,
9295-
const int nrows,
9296-
dpct::queue_ptr stream) {
9297-
GGML_ASSERT(ncols % QK_K == 0);
9298-
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
9299-
const sycl::range<3> block_nums(1, 1, block_num_y);
9300-
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
9301-
stream->parallel_for(
9302-
sycl::nd_range<3>(block_nums * block_dims, block_dims),
9303-
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
9304-
mul_mat_vec_q<QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ,
9305-
vec_dot_q3_K_q8_1>(vx, vy, dst, ncols, nrows,
9306-
item_ct1);
9307-
});
9308-
}
9309-
9310-
static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy,
9311-
float *dst, const int ncols,
9312-
const int nrows,
9313-
dpct::queue_ptr stream) {
9314-
GGML_ASSERT(ncols % QK_K == 0);
9315-
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
9316-
const sycl::range<3> block_nums(1, 1, block_num_y);
9317-
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
9318-
stream->parallel_for(
9319-
sycl::nd_range<3>(block_nums * block_dims, block_dims),
9320-
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
9321-
mul_mat_vec_q<QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ,
9322-
vec_dot_q4_K_q8_1>(vx, vy, dst, ncols, nrows,
9323-
item_ct1);
9324-
});
9325-
}
9326-
9327-
static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy,
9328-
float *dst, const int ncols,
9329-
const int nrows,
9330-
dpct::queue_ptr stream) {
9331-
GGML_ASSERT(ncols % QK_K == 0);
9332-
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
9333-
const sycl::range<3> block_nums(1, 1, block_num_y);
9334-
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
9335-
stream->parallel_for(
9336-
sycl::nd_range<3>(block_nums * block_dims, block_dims),
9337-
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
9338-
mul_mat_vec_q<QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ,
9339-
vec_dot_q5_K_q8_1>(vx, vy, dst, ncols, nrows,
9340-
item_ct1);
9341-
});
9342-
}
9343-
9344-
static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy,
9345-
float *dst, const int ncols,
9346-
const int nrows,
9347-
dpct::queue_ptr stream) {
9348-
GGML_ASSERT(ncols % QK_K == 0);
9349-
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
9350-
const sycl::range<3> block_nums(1, 1, block_num_y);
9351-
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
9352-
stream->parallel_for(
9353-
sycl::nd_range<3>(block_nums * block_dims, block_dims),
9354-
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
9355-
mul_mat_vec_q<QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ,
9356-
vec_dot_q6_K_q8_1>(vx, vy, dst, ncols, nrows,
9357-
item_ct1);
9358-
});
9191+
template <int qk, int qi, typename block_q_t, int vdr,
9192+
vec_dot_q_sycl_t vec_dot_q_sycl>
9193+
static void mul_mat_vec_q_sycl_submitter(const void *vx, const void *vy,
9194+
float *dst, const int ncols,
9195+
const int nrows,
9196+
dpct::queue_ptr stream) {
9197+
GGML_ASSERT(ncols % QK4_0 == 0);
9198+
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
9199+
const sycl::range<3> block_nums(1, 1, block_num_y);
9200+
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
9201+
stream->parallel_for(
9202+
sycl::nd_range<3>(block_nums * block_dims, block_dims), [=
9203+
](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
9204+
mul_mat_vec_q<qk, qi, block_q_t, vdr, vec_dot_q_sycl>(
9205+
vx, vy, dst, ncols, nrows, item_ct1);
9206+
});
93599207
}
93609208

93619209
int get_device_index_by_id(int id){
@@ -12095,37 +11943,63 @@ inline void ggml_sycl_op_mul_mat_vec_q(
1209511943
const int64_t ne00 = src0->ne[0];
1209611944
const int64_t row_diff = row_high - row_low;
1209711945

11946+
// TODO: support these quantization types
11947+
GGML_ASSERT(!(src0->type == GGML_TYPE_IQ2_XXS ||
11948+
src0->type == GGML_TYPE_IQ2_XS ||
11949+
src0->type == GGML_TYPE_IQ3_XXS ||
11950+
src0->type == GGML_TYPE_IQ1_S));
11951+
1209811952
switch (src0->type) {
1209911953
case GGML_TYPE_Q4_0:
12100-
mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
12101-
break;
11954+
mul_mat_vec_q_sycl_submitter<QK4_0, QI4_0, block_q4_0,
11955+
VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>(
11956+
src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
11957+
break;
1210211958
case GGML_TYPE_Q4_1:
12103-
mul_mat_vec_q4_1_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
12104-
break;
11959+
mul_mat_vec_q_sycl_submitter<QK4_1, QI4_1, block_q4_1,
11960+
VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>(
11961+
src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
11962+
break;
1210511963
case GGML_TYPE_Q5_0:
12106-
mul_mat_vec_q5_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
12107-
break;
11964+
mul_mat_vec_q_sycl_submitter<QK5_0, QI5_0, block_q5_0,
11965+
VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>(
11966+
src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
11967+
break;
1210811968
case GGML_TYPE_Q5_1:
12109-
mul_mat_vec_q5_1_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
12110-
break;
11969+
mul_mat_vec_q_sycl_submitter<QK5_1, QI5_1, block_q5_1,
11970+
VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>(
11971+
src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
11972+
break;
1211111973
case GGML_TYPE_Q8_0:
12112-
mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
12113-
break;
11974+
mul_mat_vec_q_sycl_submitter<QK8_0, QI8_0, block_q8_0,
11975+
VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>(
11976+
src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
11977+
break;
1211411978
case GGML_TYPE_Q2_K:
12115-
mul_mat_vec_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
12116-
break;
11979+
mul_mat_vec_q_sycl_submitter<QK_K, QI2_K, block_q2_K,
11980+
VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>(
11981+
src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
11982+
break;
1211711983
case GGML_TYPE_Q3_K:
12118-
mul_mat_vec_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
12119-
break;
11984+
mul_mat_vec_q_sycl_submitter<QK_K, QI3_K, block_q3_K,
11985+
VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>(
11986+
src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
11987+
break;
1212011988
case GGML_TYPE_Q4_K:
12121-
mul_mat_vec_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
12122-
break;
11989+
mul_mat_vec_q_sycl_submitter<QK_K, QI4_K, block_q4_K,
11990+
VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>(
11991+
src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
11992+
break;
1212311993
case GGML_TYPE_Q5_K:
12124-
mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
12125-
break;
11994+
mul_mat_vec_q_sycl_submitter<QK_K, QI5_K, block_q5_K,
11995+
VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>(
11996+
src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
11997+
break;
1212611998
case GGML_TYPE_Q6_K:
12127-
mul_mat_vec_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
12128-
break;
11999+
mul_mat_vec_q_sycl_submitter<QK_K, QI6_K, block_q6_K,
12000+
VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>(
12001+
src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
12002+
break;
1212912003
default:
1213012004
GGML_ASSERT(false);
1213112005
break;
@@ -12145,7 +12019,7 @@ inline void ggml_sycl_op_dequantize_mul_mat_vec(
1214512019
const int64_t src1_ncols, const int64_t src1_padded_row_size,
1214612020
const dpct::queue_ptr &stream) {
1214712021

12148-
GGML_TENSOR_BINARY_OP_LOCALS
12022+
GGML_TENSOR_BINARY_OP_LOCALS;
1214912023

1215012024
const int64_t row_diff = row_high - row_low;
1215112025

@@ -15093,6 +14967,12 @@ static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, const ggml_ten
1509314967
return false;
1509414968
}
1509514969

14970+
if (a->type == GGML_TYPE_IQ1_S) {
14971+
return false;
14972+
}
14973+
if (a->type == GGML_TYPE_IQ3_XXS) {
14974+
return false;
14975+
}
1509614976
if (a->type == GGML_TYPE_IQ2_XXS) {
1509714977
return false;
1509814978
}

0 commit comments

Comments
 (0)