@@ -85,7 +85,6 @@ Following definition copied from DPCT head files, which are used by ggml-sycl.cp
85
85
#endif
86
86
87
87
bool ggml_sycl_loaded(void);
88
- bool ggml_sycl_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
89
88
void ggml_sycl_free_data(struct ggml_tensor * tensor);
90
89
void ggml_sycl_assign_buffers(struct ggml_tensor * tensor);
91
90
void ggml_sycl_assign_buffers_no_scratch(struct ggml_tensor * tensor);
@@ -11434,21 +11433,6 @@ static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, const ggml_tenso
11434
11433
GGML_SYCL_DEBUG("call %s done\n", __func__);
11435
11434
}
11436
11435
11437
- bool ggml_sycl_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
11438
- if (!g_sycl_loaded) return false;
11439
-
11440
- const int64_t ne10 = src1->ne[0];
11441
-
11442
- const int64_t ne0 = dst->ne[0];
11443
- const int64_t ne1 = dst->ne[1];
11444
-
11445
- // TODO: find the optimal values for these
11446
- return (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
11447
- src1->type == GGML_TYPE_F32 &&
11448
- dst->type == GGML_TYPE_F32 &&
11449
- (ne0 >= 32 && ne1 >= 32 && ne10 >= 32);
11450
- }
11451
-
11452
11436
static void ggml_sycl_mul_mat_vec_p021(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
11453
11437
const ggml_tensor *src1,
11454
11438
ggml_tensor *dst) try {
@@ -12254,13 +12238,13 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
12254
12238
func = ggml_sycl_rms_norm;
12255
12239
break;
12256
12240
case GGML_OP_MUL_MAT:
12257
- if (ggml_sycl_can_mul_mat( tensor->src[0], tensor->src[1], tensor) ) {
12241
+ if (tensor->src[0]->ne[3] != tensor->src[1]->ne[3] ) {
12258
12242
return false;
12259
12243
}
12260
12244
func = ggml_sycl_mul_mat;
12261
12245
break;
12262
12246
case GGML_OP_MUL_MAT_ID:
12263
- if (ggml_sycl_can_mul_mat( tensor->src[2], tensor->src[1], tensor) ) {
12247
+ if (tensor->src[0]->ne[3] != tensor->src[1]->ne[3] ) {
12264
12248
return false;
12265
12249
}
12266
12250
func = ggml_sycl_mul_mat_id;
0 commit comments