Skip to content

Commit 40520b1

Browse files
committed
update mul_mat condition
1 parent 33008f1 commit 40520b1

File tree

1 file changed

+2
-18
lines changed

1 file changed

+2
-18
lines changed

ggml-sycl.cpp

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,6 @@ Following definition copied from DPCT head files, which are used by ggml-sycl.cp
85 85
#endif
86 86

87 87
bool ggml_sycl_loaded(void);
88-
bool ggml_sycl_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
89 88
void ggml_sycl_free_data(struct ggml_tensor * tensor);
90 89
void ggml_sycl_assign_buffers(struct ggml_tensor * tensor);
91 90
void ggml_sycl_assign_buffers_no_scratch(struct ggml_tensor * tensor);
@@ -11434,21 +11433,6 @@ static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, const ggml_tenso
11434 11433
GGML_SYCL_DEBUG("call %s done\n", __func__);
11435 11434
}
11436 11435

11437-
bool ggml_sycl_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
11438-
if (!g_sycl_loaded) return false;
11439-
11440-
const int64_t ne10 = src1->ne[0];
11441-
11442-
const int64_t ne0 = dst->ne[0];
11443-
const int64_t ne1 = dst->ne[1];
11444-
11445-
// TODO: find the optimal values for these
11446-
return (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
11447-
src1->type == GGML_TYPE_F32 &&
11448-
dst->type == GGML_TYPE_F32 &&
11449-
(ne0 >= 32 && ne1 >= 32 && ne10 >= 32);
11450-
}
11451-
11452 11436
static void ggml_sycl_mul_mat_vec_p021(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
11453 11437
const ggml_tensor *src1,
11454 11438
ggml_tensor *dst) try {
@@ -12254,13 +12238,13 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
12254 12238
func = ggml_sycl_rms_norm;
12255 12239
break;
12256 12240
case GGML_OP_MUL_MAT:
12257-
if (ggml_sycl_can_mul_mat(tensor->src[0], tensor->src[1], tensor)) {
12241+
if (tensor->src[0]->ne[3] != tensor->src[1]->ne[3]) {
12258 12242
return false;
12259 12243
}
12260 12244
func = ggml_sycl_mul_mat;
12261 12245
break;
12262 12246
case GGML_OP_MUL_MAT_ID:
12263-
if (ggml_sycl_can_mul_mat(tensor->src[2], tensor->src[1], tensor)) {
12247+
if (tensor->src[0]->ne[3] != tensor->src[1]->ne[3]) {
12264 12248
return false;
12265 12249
}
12266 12250
func = ggml_sycl_mul_mat_id;

0 commit comments

Comments
 (0)