@@ -3493,10 +3493,6 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
3493
3493
SYCL_CHECK (ggml_sycl_set_device (ctx.device ));
3494
3494
queue_ptr main_stream = ctx.stream ();;
3495
3495
3496
- bool no_mixed_dtypes = main_stream->get_backend () == sycl::backend::ext_oneapi_cuda ||
3497
- main_stream->get_backend () == sycl::backend::ext_oneapi_hip;
3498
-
3499
-
3500
3496
void * src0_ddq = src0->data ;
3501
3497
sycl::half *src0_as_f16 = (sycl::half *)src0_ddq;
3502
3498
float * src1_ddf = (float *) src1->data ;
@@ -3514,15 +3510,10 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
3514
3510
sycl::half *src1_f16 = src1->type == GGML_TYPE_F16 ? (sycl::half *)src1_ddf
3515
3511
: src1_f16_alloc.get ();
3516
3512
3517
- ggml_sycl_pool_alloc<sycl::half> dst_f16 (ctx.pool ());
3518
3513
char * dst_t ;
3519
3514
3520
3515
dpct::library_data_t cu_compute_type = dpct::library_data_t ::real_float;
3521
3516
dpct::library_data_t cu_data_type = dpct::library_data_t ::real_float;
3522
- if (no_mixed_dtypes) {
3523
- cu_compute_type = dpct::library_data_t ::real_half;
3524
- cu_data_type = dpct::library_data_t ::real_half;
3525
- }
3526
3517
3527
3518
// dst strides
3528
3519
size_t nbd2 = dst->nb [2 ];
@@ -3531,26 +3522,10 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
3531
3522
const float alpha_f32 = 1 .0f ;
3532
3523
const float beta_f32 = 0 .0f ;
3533
3524
3534
- const sycl::half alpha_f16 = 1 .0f ;
3535
- const sycl::half beta_f16 = 0 .0f ;
3536
-
3537
3525
const void * alpha = &alpha_f32;
3538
3526
const void * beta = &beta_f32;
3539
- if (no_mixed_dtypes) {
3540
- alpha = &alpha_f16;
3541
- beta = &beta_f16;
3542
- }
3543
-
3544
- // TODO: Renable (dst->op_params[0] =! GGML_PREC_DEFAULT) pathway
3545
- // when oneMKL open source supports half, half, float, float: datatypes
3546
3527
3547
3528
dst_t = (char *) dst_ddf;
3548
- if (no_mixed_dtypes) {
3549
- dst_t = (char *) dst_f16.alloc (ne_dst);
3550
-
3551
- nbd2 /= sizeof (float ) / sizeof (sycl::half);
3552
- nbd3 /= sizeof (float ) / sizeof (sycl::half);
3553
- }
3554
3529
3555
3530
GGML_ASSERT (ne12 % ne02 == 0 );
3556
3531
GGML_ASSERT (ne13 % ne03 == 0 );
@@ -3612,11 +3587,6 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
3612
3587
(void **)(ptrs_dst.get () + 0 * ne23), cu_data_type, ne01, ne23,
3613
3588
cu_compute_type)));
3614
3589
}
3615
-
3616
- if (no_mixed_dtypes) {
3617
- const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl (GGML_TYPE_F16);
3618
- to_fp32_sycl (dst_f16.get (), dst_ddf, ne_dst, main_stream);
3619
- }
3620
3590
}
3621
3591
catch (sycl::exception const &exc) {
3622
3592
std::cerr << exc.what () << " Exception caught at file:" << __FILE__
0 commit comments