Skip to content

Commit 603fdbf

Browse files
committed
Fix binbcast
1 parent 4731850 commit 603fdbf

File tree

2 files changed

+9
-8
lines changed

2 files changed

+9
-8
lines changed

ggml/src/ggml-sycl/binbcast.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include "binbcast.hpp"
22
#include <sycl/sycl.hpp>
3+
#include "ggml.h"
34

45
template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
56
static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
@@ -263,16 +264,16 @@ inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_t
263264

264265
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
265266
op()(ctx, src0, src1, dst, (const float *)src0->data, (const float *)src1->data, (float *)dst->data, main_stream);
266-
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
267-
op()(ctx, src0, src1, dst, (const sycl::half *)src0->data, (const float *)src1->data,
267+
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
268+
op()(ctx, src0, src1, dst, (const sycl::half *)src0->data, (const sycl::half *)src1->data,
268269
(sycl::half *)dst->data, main_stream);
269-
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
270-
op()(ctx, src0, src1, dst, (const sycl::half *)src0->data, (const float *)src1->data, (float *)dst->data,
270+
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
271+
op()(ctx, src0, src1, dst, (const sycl::half *)src0->data, (const float *)src1->data, (sycl::half *)dst->data,
271272
main_stream);
272-
} else if (src0->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) {
273+
} else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) {
273274
op()(ctx, src0, src1, dst, (const int32_t *)src0->data, (const int32_t *)src1->data, (int32_t *)dst->data,
274275
main_stream);
275-
} else if (src0->type == GGML_TYPE_I16 && dst->type == GGML_TYPE_I16) {
276+
} else if (src0->type == GGML_TYPE_I16 && src1->type == GGML_TYPE_I16 && dst->type == GGML_TYPE_I16) {
276277
op()(ctx, src0, src1, dst, (const int16_t *)src0->data, (const int16_t *)src1->data, (int16_t *)dst->data,
277278
main_stream);
278279
} else {

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3961,7 +3961,6 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
39613961
case GGML_OP_ARGMAX:
39623962
case GGML_OP_NONE:
39633963
case GGML_OP_RESHAPE:
3964-
case GGML_OP_REPEAT:
39653964
case GGML_OP_VIEW:
39663965
case GGML_OP_PERMUTE:
39673966
case GGML_OP_TRANSPOSE:
@@ -3971,7 +3970,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
39713970
case GGML_OP_SUB:
39723971
case GGML_OP_MUL:
39733972
case GGML_OP_DIV:
3974-
return (op->src[0]->type == GGML_TYPE_F32);
3973+
case GGML_OP_REPEAT:
3974+
return true;
39753975
case GGML_OP_SQR:
39763976
case GGML_OP_SQRT:
39773977
case GGML_OP_SIN:

0 commit comments

Comments
 (0)