@@ -2971,7 +2971,7 @@ static int g_work_group_size = 0;
 // typedef sycl::half ggml_fp16_t;

 #define __SYCL_ARCH__ DPCT_COMPATIBILITY_TEMP
-#define VER_4VEC 610 //todo for hardware optimize.
+#define VER_4VEC 130 //todo for hardware optimize.
 #define VER_GEN9 700 //todo for hardware optimize.
 #define VER_GEN12 1000000 //todo for hardware optimize.
 #define VER_GEN13 (VER_GEN12 + 1030) //todo for hardware optimize.
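
The VER_4VEC threshold is compared against the device's minimum compute capability before the dp4a-based quantized kernels are chosen (see the removed dispatch code further down); lowering it from 610 to 130 lets more device generations pass the check. A minimal sketch of that gate; the helper name and scaffolding are illustrative, not part of the patch:

// Illustrative sketch of the capability gate used elsewhere in this file.
#define VER_4VEC 130    // after this patch; was 610

// true when the device generation is new enough for the 4-wide
// (dp4a-style) quantized vector kernels
static bool supports_4vec(int min_compute_capability) {
    return min_compute_capability >= VER_4VEC;
}
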
@@ -2984,7 +2984,7 @@ static int g_work_group_size = 0;
 #define SYCL_USE_XMX

 // max batch size to use MMQ kernels when tensor cores are available
-#define XMX_MAX_BATCH_SIZE 32
+#define MMQ_MAX_BATCH_SIZE 32


 #if defined(_MSC_VER)
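
The rename from XMX_MAX_BATCH_SIZE to MMQ_MAX_BATCH_SIZE matches what the constant actually caps: the batch size up to which the quantized MMQ kernels remain worth using when the fp16 XMX (tensor-core) path is available. A sketch of the gate, mirroring the #ifdef SYCL_USE_XMX condition added to ggml_sycl_mul_mat below; the helper and its batch_size parameter are hypothetical stand-ins for src1->ne[1]:

#include <cstdint>

#define MMQ_MAX_BATCH_SIZE 32

// Illustrative only: with fast fp16 available, quantized MMQ pays off
// only up to a modest batch size; beyond the cap the fp16 path wins.
static bool mmq_wins_over_xmx(bool fp16_performance_good, int64_t batch_size) {
    return !fp16_performance_good || batch_size <= MMQ_MAX_BATCH_SIZE;
}
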
@@ -15173,6 +15173,25 @@ catch (sycl::exception const &exc) {
     std::exit(1);
 }

+bool ggml_sycl_supports_mmq(enum ggml_type type) {
+    // TODO: accuracy issues in MMQ
+    return false;
+    // switch (type) {
+    //     case GGML_TYPE_Q4_0:
+    //     case GGML_TYPE_Q4_1:
+    //     case GGML_TYPE_Q5_0:
+    //     case GGML_TYPE_Q5_1:
+    //     case GGML_TYPE_Q8_0:
+    //     case GGML_TYPE_Q2_K:
+    //     case GGML_TYPE_Q3_K:
+    //     case GGML_TYPE_Q4_K:
+    //     case GGML_TYPE_Q5_K:
+    //     case GGML_TYPE_Q6_K:
+    //         return true;
+    //     default:
+    //         return false;
+    // }
+}

 static void ggml_sycl_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const bool all_on_device =
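
The new ggml_sycl_supports_mmq() hook returns false unconditionally for now because of the accuracy issues noted in the TODO; the intended per-type support survives in the commented-out switch. For reference, a hypothetical version with that switch restored, which is presumably what re-enabling MMQ would look like once the accuracy problems are fixed (assumes ggml.h for the type enum):

#include "ggml.h"

// Hypothetical: ggml_sycl_supports_mmq with the commented-out switch
// restored, enabling the classic Q types and the K-quants.
bool ggml_sycl_supports_mmq(enum ggml_type type) {
    switch (type) {
        case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1:
        case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1:
        case GGML_TYPE_Q8_0:
        case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K:
        case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K:
        case GGML_TYPE_Q6_K:
            return true;
        default:
            return false;
    }
}
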
@@ -15189,75 +15208,59 @@ static void ggml_sycl_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
         }
     }

+#if !defined(GGML_SYCL_FORCE_MMQ)
+    #define SYCL_USE_XMX
+#endif
+
 #ifdef SYCL_USE_XMX
-    const bool use_xmx = true;
+    bool use_xmx = true;
 #else
-    const bool use_xmx = false;
+    bool use_xmx = false;
 #endif

-    // debug helpers
-    //printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
-    //printf(" %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
-    //printf("src1: %8d %8d %8d %8d\n", src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3]);
-    //printf(" %8d %8d %8d %8d\n", src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3]);
-    //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
-    //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
+    // check data types and tensor shapes for custom matrix multiplication kernels:
+    bool use_dequantize_mul_mat_vec = (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16)
+        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
+        && src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1;

-    if (!split && all_on_device && !use_xmx && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
+    bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
+        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
+        && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
+
+    bool use_mul_mat_q = ggml_sycl_supports_mmq(src0->type)
+        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
+
+    // fp16 performance is always better on gen12+
+    const bool fp16_performance_good = true;
+
+    // mmvq and mmq need the __dp4a instruction which is available for gen12+
+    use_mul_mat_vec_q = use_mul_mat_vec_q; // check dp4a
+    use_mul_mat_q = use_mul_mat_q; // check dp4a
+#ifdef SYCL_USE_XMX
+    use_mul_mat_q = use_mul_mat_q && (!fp16_performance_good || src1->ne[1] <= MMQ_MAX_BATCH_SIZE);
+#endif // SYCL_USE_XMX
+
+#ifdef GGML_SYCL_FORCE_DMMV
+    use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q;
+#endif // GGML_SYCL_FORCE_DMMV
+
+    if (!split && !fp16_performance_good && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
         // KQ single-batch
-        // GGML_SYCL_DEBUG("ggml_sycl_mul_mat_vec_p021\n");
         ggml_sycl_mul_mat_vec_p021(src0, src1, dst);
-    } else if (!split && all_on_device && !use_xmx && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
+    } else if (!split && !fp16_performance_good && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
         // KQV single-batch
-        // GGML_SYCL_DEBUG("ggml_sycl_mul_mat_vec_nc\n");
         ggml_sycl_mul_mat_vec_nc(src0, src1, dst);
-    } else if (!split && all_on_device && use_xmx && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1)) {
+    } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || fp16_performance_good) && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
         // KQ + KQV multi-batch
-        // GGML_SYCL_DEBUG("ggml_sycl_mul_mat_batched_sycl\n");
         ggml_sycl_mul_mat_batched_sycl(src0, src1, dst);
-    } else if (src0->type == GGML_TYPE_F32) {
-        // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat\n");
-        ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
-    } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) {
-        // GGML_SYCL_DEBUG("ggml_is_quantized or GGML_TYPE_F16\n");
-        if (src1->ne[1] == 1 && src0->ne[0] % GGML_SYCL_DMMV_X == 0) {
-#ifdef GGML_SYCL_FORCE_DMMV
-            const bool use_mul_mat_vec_q = false;
-#else
-            bool use_mul_mat_vec_q = min_compute_capability >= VER_4VEC && ggml_is_quantized(src0->type);
-            use_mul_mat_vec_q = use_mul_mat_vec_q ||
-                (src0->type == GGML_TYPE_IQ2_XXS) || (src0->type == GGML_TYPE_IQ2_XS) || (src0->type == GGML_TYPE_IQ2_S) ||
-                (src0->type == GGML_TYPE_IQ3_XXS) || (src0->type == GGML_TYPE_IQ3_S) ||
-                (src0->type == GGML_TYPE_IQ4_NL) || (src0->type == GGML_TYPE_IQ4_XS) ||
-                (src0->type == GGML_TYPE_IQ1_S) || (src0->type == GGML_TYPE_IQ1_M);
-
-
-#endif // GGML_SYCL_FORCE_DMMV
-
-            if (use_mul_mat_vec_q) {
-                // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_mul_mat_vec_q path\n");
-                ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true);
-            } else {
-                // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_dequantize_mul_mat_vec path\n");
-                ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false);
-            }
-        } else {
-            bool use_mul_mat_q = min_compute_capability >= VER_4VEC && ggml_is_quantized(src0->type);
-
-            if (use_xmx && min_compute_capability >= VER_GEN9 && src1->ne[1] > XMX_MAX_BATCH_SIZE) {
-                use_mul_mat_q = false;
-            }
-
-            if (use_mul_mat_q) {
-                // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_mul_mat_q path\n");
-                ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_q, true);
-            } else {
-                // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_mul_mat_sycl path\n");
-                ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
-            }
-        }
+    } else if (use_dequantize_mul_mat_vec) {
+        ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false);
+    } else if (use_mul_mat_vec_q) {
+        ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true);
+    } else if (use_mul_mat_q) {
+        ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_q, true);
     } else {
-        GGML_ASSERT(false);
+        ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
     }
 }
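
Taken together, once the fp16 fast paths (p021, nc, batched) are ruled out, the new code reduces kernel choice to a fixed priority order, with the dense SYCL GEMM replacing the old GGML_ASSERT(false) as a universal fallback. A condensed recap, not part of the patch, with the predicates assumed to be precomputed exactly as in the diff above:

// Illustrative recap of the new selection order in ggml_sycl_mul_mat;
// the flags correspond one-to-one to the booleans computed in the diff.
enum class mul_mat_kernel { dmmv, mmvq, mmq, sycl_gemm };

static mul_mat_kernel choose_kernel(bool use_dequantize_mul_mat_vec,
                                    bool use_mul_mat_vec_q,
                                    bool use_mul_mat_q) {
    if (use_dequantize_mul_mat_vec) return mul_mat_kernel::dmmv;  // fp16/quantized, single column
    if (use_mul_mat_vec_q)          return mul_mat_kernel::mmvq;  // quantized, small batch
    if (use_mul_mat_q)              return mul_mat_kernel::mmq;   // quantized GEMM (stubbed off for now)
    return mul_mat_kernel::sycl_gemm;                             // dense fallback, replaces GGML_ASSERT(false)
}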