@@ -468,12 +468,6 @@ static __dpct_inline__ Tp* get_pointer(sycl::local_accessor<Tp, dim> acc) {
468
468
469
469
int64_t downsample_sycl_global_range (int64_t accumulate_block_num, int64_t block_size);
470
470
471
- typedef void (*ggml_sycl_op_flatten_t )(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
472
- const ggml_tensor *src1,
473
- ggml_tensor *dst, const float *src0_dd,
474
- const float *src1_dd, float *dst_dd,
475
- const queue_ptr &main_stream);
476
-
477
471
template <float (*bin_op)(const float , const float ), typename src0_t , typename src1_t , typename dst_t >
478
472
static void k_bin_bcast (const src0_t * src0, const src1_t * src1, dst_t * dst,
479
473
int ne0, int ne1, int ne2, int ne3,
@@ -731,24 +725,23 @@ struct bin_bcast_sycl {
731
725
732
726
template <class op >
733
727
inline void ggml_sycl_op_bin_bcast (ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
734
- const ggml_tensor *src1, ggml_tensor *dst,
735
- const float *src0_dd, const float *src1_dd,
736
- float *dst_dd,
737
- const queue_ptr &main_stream) {
728
+ const ggml_tensor *src1, ggml_tensor *dst) {
729
+ /* TODO: Refactor bin_bcast */
730
+ dpct::queue_ptr main_stream = ctx.stream ();
738
731
739
732
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
740
- op ()(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd , main_stream);
733
+ op ()(ctx, src0, src1, dst, ( const float *)src0-> data , ( const float *)src1-> data , ( float *)dst-> data , main_stream);
741
734
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
742
- op ()(ctx, src0, src1, dst, (const sycl::half *)src0_dd, src1_dd ,
743
- (sycl::half *)dst_dd , main_stream);
735
+ op ()(ctx, src0, src1, dst, (const sycl::half *)src0-> data , ( const float *)src1-> data ,
736
+ (sycl::half *)dst-> data , main_stream);
744
737
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
745
- op ()(ctx, src0, src1, dst, (const sycl::half *)src0_dd, src1_dd, dst_dd ,
738
+ op ()(ctx, src0, src1, dst, (const sycl::half *)src0-> data , ( const float *)src1-> data , ( float *)dst-> data ,
746
739
main_stream);
747
740
} else if (src0->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) {
748
- op ()(ctx, src0, src1, dst, (const int32_t *)src0_dd , (const int32_t *)src1_dd , (int32_t *)dst_dd ,
741
+ op ()(ctx, src0, src1, dst, (const int32_t *)src0-> data , (const int32_t *)src1-> data , (int32_t *)dst-> data ,
749
742
main_stream);
750
743
} else if (src0->type == GGML_TYPE_I16 && dst->type == GGML_TYPE_I16) {
751
- op ()(ctx, src0, src1, dst, (const int16_t *)src0_dd , (const int16_t *)src1_dd , (int16_t *)dst_dd ,
744
+ op ()(ctx, src0, src1, dst, (const int16_t *)src0-> data , (const int16_t *)src1-> data , (int16_t *)dst-> data ,
752
745
main_stream);
753
746
} else {
754
747
fprintf (stderr, " %s: unsupported types: dst: %s, src0: %s, src1: %s\n " , __func__,
@@ -758,8 +751,4 @@ inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_t
758
751
}
759
752
760
753
bool gpu_has_xmx (sycl::device &dev);
761
-
762
- void ggml_sycl_op_flatten (ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
763
- const ggml_tensor *src1, ggml_tensor *dst,
764
- const ggml_sycl_op_flatten_t op);
765
754
#endif // GGML_SYCL_COMMON_HPP
0 commit comments