|
1 | 1 | #include "binbcast.hpp"
|
2 | 2 | #include <sycl/sycl.hpp>
|
| 3 | +#include "ggml.h" |
3 | 4 |
|
4 | 5 | template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
|
5 | 6 | static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
|
@@ -263,16 +264,16 @@ inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_t
|
263 | 264 |
|
264 | 265 | if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
265 | 266 | op()(ctx, src0, src1, dst, (const float *)src0->data, (const float *)src1->data, (float *)dst->data, main_stream);
|
266 |
| - } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { |
267 |
| - op()(ctx, src0, src1, dst, (const sycl::half *)src0->data, (const float *)src1->data, |
| 267 | + } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { |
| 268 | + op()(ctx, src0, src1, dst, (const sycl::half *)src0->data, (const sycl::half *)src1->data, |
268 | 269 | (sycl::half *)dst->data, main_stream);
|
269 |
| - } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) { |
270 |
| - op()(ctx, src0, src1, dst, (const sycl::half *)src0->data, (const float *)src1->data, (float *)dst->data, |
| 270 | + } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) { |
| 271 | + op()(ctx, src0, src1, dst, (const sycl::half *)src0->data, (const float *)src1->data, (sycl::half *)dst->data, |
271 | 272 | main_stream);
|
272 |
| - } else if (src0->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) { |
| 273 | + } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) { |
273 | 274 | op()(ctx, src0, src1, dst, (const int32_t *)src0->data, (const int32_t *)src1->data, (int32_t *)dst->data,
|
274 | 275 | main_stream);
|
275 |
| - } else if (src0->type == GGML_TYPE_I16 && dst->type == GGML_TYPE_I16) { |
| 276 | + } else if (src0->type == GGML_TYPE_I16 && src1->type == GGML_TYPE_I16 && dst->type == GGML_TYPE_I16) { |
276 | 277 | op()(ctx, src0, src1, dst, (const int16_t *)src0->data, (const int16_t *)src1->data, (int16_t *)dst->data,
|
277 | 278 | main_stream);
|
278 | 279 | } else {
|
|
0 commit comments