@@ -9188,174 +9188,22 @@ static void convert_mul_mat_vec_f16_sycl(const void *vx, const dfloat *y,
9188
9188
}
9189
9189
}
9190
9190
9191
- static void mul_mat_vec_q4_0_q8_1_sycl(const void *vx, const void *vy,
9192
- float *dst, const int ncols,
9193
- const int nrows,
9194
- dpct::queue_ptr stream) {
9195
- GGML_ASSERT(ncols % QK4_0 == 0);
9196
- const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
9197
- const sycl::range<3> block_nums(1, 1, block_num_y);
9198
- const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
9199
- stream->parallel_for(
9200
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
9201
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
9202
- mul_mat_vec_q<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ,
9203
- vec_dot_q4_0_q8_1>(vx, vy, dst, ncols, nrows,
9204
- item_ct1);
9205
- });
9206
- }
9207
-
9208
- static void mul_mat_vec_q4_1_q8_1_sycl(const void *vx, const void *vy,
9209
- float *dst, const int ncols,
9210
- const int nrows,
9211
- dpct::queue_ptr stream) {
9212
- GGML_ASSERT(ncols % QK4_1 == 0);
9213
- const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
9214
- const sycl::range<3> block_nums(1, 1, block_num_y);
9215
- const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
9216
- stream->parallel_for(
9217
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
9218
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
9219
- mul_mat_vec_q<QK4_0, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ,
9220
- vec_dot_q4_1_q8_1>(vx, vy, dst, ncols, nrows,
9221
- item_ct1);
9222
- });
9223
- }
9224
-
9225
- static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy,
9226
- float *dst, const int ncols,
9227
- const int nrows,
9228
- dpct::queue_ptr stream) {
9229
- GGML_ASSERT(ncols % QK5_0 == 0);
9230
- const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
9231
- const sycl::range<3> block_nums(1, 1, block_num_y);
9232
- const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
9233
- stream->parallel_for(
9234
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
9235
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
9236
- mul_mat_vec_q<QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ,
9237
- vec_dot_q5_0_q8_1>(vx, vy, dst, ncols, nrows,
9238
- item_ct1);
9239
- });
9240
- }
9241
-
9242
- static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy,
9243
- float *dst, const int ncols,
9244
- const int nrows,
9245
- dpct::queue_ptr stream) {
9246
- GGML_ASSERT(ncols % QK5_1 == 0);
9247
- const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
9248
- const sycl::range<3> block_nums(1, 1, block_num_y);
9249
- const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
9250
- stream->parallel_for(
9251
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
9252
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
9253
- mul_mat_vec_q<QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ,
9254
- vec_dot_q5_1_q8_1>(vx, vy, dst, ncols, nrows,
9255
- item_ct1);
9256
- });
9257
- }
9258
-
9259
- static void mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy,
9260
- float *dst, const int ncols,
9261
- const int nrows,
9262
- dpct::queue_ptr stream) {
9263
- GGML_ASSERT(ncols % QK8_0 == 0);
9264
- const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
9265
- const sycl::range<3> block_nums(1, 1, block_num_y);
9266
- const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
9267
- stream->parallel_for(
9268
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
9269
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
9270
- mul_mat_vec_q<QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ,
9271
- vec_dot_q8_0_q8_1>(vx, vy, dst, ncols, nrows,
9272
- item_ct1);
9273
- });
9274
- }
9275
-
9276
- static void mul_mat_vec_q2_K_q8_1_sycl(const void *vx, const void *vy,
9277
- float *dst, const int ncols,
9278
- const int nrows,
9279
- dpct::queue_ptr stream) {
9280
- GGML_ASSERT(ncols % QK_K == 0);
9281
- const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
9282
- const sycl::range<3> block_nums(1, 1, block_num_y);
9283
- const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
9284
- stream->parallel_for(
9285
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
9286
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
9287
- mul_mat_vec_q<QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ,
9288
- vec_dot_q2_K_q8_1>(vx, vy, dst, ncols, nrows,
9289
- item_ct1);
9290
- });
9291
- }
9292
-
9293
- static void mul_mat_vec_q3_K_q8_1_sycl(const void *vx, const void *vy,
9294
- float *dst, const int ncols,
9295
- const int nrows,
9296
- dpct::queue_ptr stream) {
9297
- GGML_ASSERT(ncols % QK_K == 0);
9298
- const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
9299
- const sycl::range<3> block_nums(1, 1, block_num_y);
9300
- const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
9301
- stream->parallel_for(
9302
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
9303
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
9304
- mul_mat_vec_q<QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ,
9305
- vec_dot_q3_K_q8_1>(vx, vy, dst, ncols, nrows,
9306
- item_ct1);
9307
- });
9308
- }
9309
-
9310
- static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy,
9311
- float *dst, const int ncols,
9312
- const int nrows,
9313
- dpct::queue_ptr stream) {
9314
- GGML_ASSERT(ncols % QK_K == 0);
9315
- const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
9316
- const sycl::range<3> block_nums(1, 1, block_num_y);
9317
- const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
9318
- stream->parallel_for(
9319
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
9320
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
9321
- mul_mat_vec_q<QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ,
9322
- vec_dot_q4_K_q8_1>(vx, vy, dst, ncols, nrows,
9323
- item_ct1);
9324
- });
9325
- }
9326
-
9327
- static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy,
9328
- float *dst, const int ncols,
9329
- const int nrows,
9330
- dpct::queue_ptr stream) {
9331
- GGML_ASSERT(ncols % QK_K == 0);
9332
- const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
9333
- const sycl::range<3> block_nums(1, 1, block_num_y);
9334
- const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
9335
- stream->parallel_for(
9336
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
9337
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
9338
- mul_mat_vec_q<QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ,
9339
- vec_dot_q5_K_q8_1>(vx, vy, dst, ncols, nrows,
9340
- item_ct1);
9341
- });
9342
- }
9343
-
9344
- static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy,
9345
- float *dst, const int ncols,
9346
- const int nrows,
9347
- dpct::queue_ptr stream) {
9348
- GGML_ASSERT(ncols % QK_K == 0);
9349
- const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
9350
- const sycl::range<3> block_nums(1, 1, block_num_y);
9351
- const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
9352
- stream->parallel_for(
9353
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
9354
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
9355
- mul_mat_vec_q<QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ,
9356
- vec_dot_q6_K_q8_1>(vx, vy, dst, ncols, nrows,
9357
- item_ct1);
9358
- });
9191
+ template <int qk, int qi, typename block_q_t, int vdr,
9192
+ vec_dot_q_sycl_t vec_dot_q_sycl>
9193
+ static void mul_mat_vec_q_sycl_submitter(const void *vx, const void *vy,
9194
+ float *dst, const int ncols,
9195
+ const int nrows,
9196
+ dpct::queue_ptr stream) {
9197
+ GGML_ASSERT(ncols % QK4_0 == 0);
9198
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
9199
+ const sycl::range<3> block_nums(1, 1, block_num_y);
9200
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
9201
+ stream->parallel_for(
9202
+ sycl::nd_range<3>(block_nums * block_dims, block_dims), [=
9203
+ ](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
9204
+ mul_mat_vec_q<qk, qi, block_q_t, vdr, vec_dot_q_sycl>(
9205
+ vx, vy, dst, ncols, nrows, item_ct1);
9206
+ });
9359
9207
}
9360
9208
9361
9209
int get_device_index_by_id(int id){
@@ -12095,37 +11943,63 @@ inline void ggml_sycl_op_mul_mat_vec_q(
12095
11943
const int64_t ne00 = src0->ne[0];
12096
11944
const int64_t row_diff = row_high - row_low;
12097
11945
11946
+ // TODO: support these quantization types
11947
+ GGML_ASSERT(!(src0->type == GGML_TYPE_IQ2_XXS ||
11948
+ src0->type == GGML_TYPE_IQ2_XS ||
11949
+ src0->type == GGML_TYPE_IQ3_XXS ||
11950
+ src0->type == GGML_TYPE_IQ1_S));
11951
+
12098
11952
switch (src0->type) {
12099
11953
case GGML_TYPE_Q4_0:
12100
- mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
12101
- break;
11954
+ mul_mat_vec_q_sycl_submitter<QK4_0, QI4_0, block_q4_0,
11955
+ VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>(
11956
+ src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
11957
+ break;
12102
11958
case GGML_TYPE_Q4_1:
12103
- mul_mat_vec_q4_1_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
12104
- break;
11959
+ mul_mat_vec_q_sycl_submitter<QK4_1, QI4_1, block_q4_1,
11960
+ VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>(
11961
+ src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
11962
+ break;
12105
11963
case GGML_TYPE_Q5_0:
12106
- mul_mat_vec_q5_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
12107
- break;
11964
+ mul_mat_vec_q_sycl_submitter<QK5_0, QI5_0, block_q5_0,
11965
+ VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>(
11966
+ src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
11967
+ break;
12108
11968
case GGML_TYPE_Q5_1:
12109
- mul_mat_vec_q5_1_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
12110
- break;
11969
+ mul_mat_vec_q_sycl_submitter<QK5_1, QI5_1, block_q5_1,
11970
+ VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>(
11971
+ src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
11972
+ break;
12111
11973
case GGML_TYPE_Q8_0:
12112
- mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
12113
- break;
11974
+ mul_mat_vec_q_sycl_submitter<QK8_0, QI8_0, block_q8_0,
11975
+ VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>(
11976
+ src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
11977
+ break;
12114
11978
case GGML_TYPE_Q2_K:
12115
- mul_mat_vec_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
12116
- break;
11979
+ mul_mat_vec_q_sycl_submitter<QK_K, QI2_K, block_q2_K,
11980
+ VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>(
11981
+ src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
11982
+ break;
12117
11983
case GGML_TYPE_Q3_K:
12118
- mul_mat_vec_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
12119
- break;
11984
+ mul_mat_vec_q_sycl_submitter<QK_K, QI3_K, block_q3_K,
11985
+ VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>(
11986
+ src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
11987
+ break;
12120
11988
case GGML_TYPE_Q4_K:
12121
- mul_mat_vec_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
12122
- break;
11989
+ mul_mat_vec_q_sycl_submitter<QK_K, QI4_K, block_q4_K,
11990
+ VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>(
11991
+ src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
11992
+ break;
12123
11993
case GGML_TYPE_Q5_K:
12124
- mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
12125
- break;
11994
+ mul_mat_vec_q_sycl_submitter<QK_K, QI5_K, block_q5_K,
11995
+ VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>(
11996
+ src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
11997
+ break;
12126
11998
case GGML_TYPE_Q6_K:
12127
- mul_mat_vec_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
12128
- break;
11999
+ mul_mat_vec_q_sycl_submitter<QK_K, QI6_K, block_q6_K,
12000
+ VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>(
12001
+ src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
12002
+ break;
12129
12003
default:
12130
12004
GGML_ASSERT(false);
12131
12005
break;
@@ -12145,7 +12019,7 @@ inline void ggml_sycl_op_dequantize_mul_mat_vec(
12145
12019
const int64_t src1_ncols, const int64_t src1_padded_row_size,
12146
12020
const dpct::queue_ptr &stream) {
12147
12021
12148
- GGML_TENSOR_BINARY_OP_LOCALS
12022
+ GGML_TENSOR_BINARY_OP_LOCALS;
12149
12023
12150
12024
const int64_t row_diff = row_high - row_low;
12151
12025
@@ -15093,6 +14967,12 @@ static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, const ggml_ten
15093
14967
return false;
15094
14968
}
15095
14969
14970
+ if (a->type == GGML_TYPE_IQ1_S) {
14971
+ return false;
14972
+ }
14973
+ if (a->type == GGML_TYPE_IQ3_XXS) {
14974
+ return false;
14975
+ }
15096
14976
if (a->type == GGML_TYPE_IQ2_XXS) {
15097
14977
return false;
15098
14978
}
0 commit comments