@@ -12726,6 +12726,7 @@ static void ggml_sycl_op_mul_mat(const ggml_tensor *src0,
 
     GGML_ASSERT(dst->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
     GGML_ASSERT(src1->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32 || (src1->ne[2] == 1 && src1->ne[3] == 1));
 
     GGML_ASSERT(ne12 >= ne02 && ne12 % ne02 == 0);
 
@@ -13269,31 +13270,23 @@ static void k_compute_batched_ptrs(const sycl::half *src0_as_f16,
     int64_t i03 = i13 / r3;
     int64_t i02 = i12 / r2;
 
-    ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02 + i03*nb03;
-    ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12/2 + i13*nb13/2;
-    ptrs_dst[0*ne23 + i12 + i13*ne12] = (      char *)         dst + i12*nbd2 + i13*nbd3;
+    ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02 + i03*nb03;
+    ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12 + i13*nb13;
+    ptrs_dst[0*ne23 + i12 + i13*ne12] = (      char *)         dst + i12*nbd2 + i13*nbd3;
 }
 
-static void ggml_sycl_mul_mat_mat_batched_sycl(const ggml_tensor *src0,
-                                               const ggml_tensor *src1,
-                                               ggml_tensor *dst) try {
+static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0,
+                                           const ggml_tensor *src1,
+                                           ggml_tensor *dst) try {
     GGML_ASSERT(!ggml_is_transposed(src0));
     GGML_ASSERT(!ggml_is_transposed(src1));
 
     GGML_ASSERT(src0->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-
-    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne);
-
-    GGML_TENSOR_LOCALS(int64_t, nb0, src0, nb);
 
-    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne);
-
-    GGML_TENSOR_LOCALS(int64_t, nb1, src1, nb);
+    GGML_TENSOR_BINARY_OP_LOCALS
 
-    const int64_t ne1 = ggml_nelements(src1);
-    const int64_t ne  = ggml_nelements(dst);
+    const int64_t ne_dst = ggml_nelements(dst);
 
     SYCL_CHECK(ggml_sycl_set_device(g_main_device));
     dpct::queue_ptr main_stream = g_syclStreams[g_main_device_index][0];
@@ -13312,11 +13305,16 @@ static void ggml_sycl_mul_mat_mat_batched_sycl(const ggml_tensor *src0,
     float * dst_ddf = (float *) dst_extra->data_device[g_main_device_index];
 
     // convert src1 to fp16
-    const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type);
-    GGML_ASSERT(to_fp16_sycl != nullptr);
-
-    sycl_pool_alloc<sycl::half> src1_as_f16(ne1);
-    to_fp16_sycl(src1_ddf, src1_as_f16.get(), ne1, main_stream);
+    sycl_pool_alloc<sycl::half> src1_f16_alloc;
+    if (src1->type != GGML_TYPE_F16) {
+        const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type);
+        const int64_t ne_src1 = ggml_nelements(src1);
+        src1_f16_alloc.alloc(ne_src1);
+        GGML_ASSERT(to_fp16_sycl != nullptr);
+        to_fp16_sycl(src1_ddf, src1_f16_alloc.get(), ne_src1, main_stream);
+    }
+    sycl::half *src1_f16 = src1->type == GGML_TYPE_F16 ? (sycl::half *)src1_ddf
+                                                       : src1_f16_alloc.get();
 
     sycl_pool_alloc<sycl::half> dst_f16;
     char * dst_t;
@@ -13337,20 +13335,12 @@ static void ggml_sycl_mul_mat_mat_batched_sycl(const ggml_tensor *src0,
     const void * alpha = &alpha_f16;
     const void * beta  = &beta_f16;
 
-    if (dst->op_params[0] == GGML_PREC_DEFAULT) {
-        dst_t = (char *) dst_f16.alloc(ne);
-
-        nbd2 /= sizeof(float) / sizeof(sycl::half);
-        nbd3 /= sizeof(float) / sizeof(sycl::half);
-    } else {
-        dst_t = (char *) dst_ddf;
-
-        cu_compute_type = dpct::library_data_t::real_float;
-        cu_data_type    = dpct::library_data_t::real_float;
+    // TODO: Re-enable the (dst->op_params[0] != GGML_PREC_DEFAULT) pathway
+    // once oneMKL open source supports half, half, float, float datatypes
+    dst_t = (char *) dst_f16.alloc(ne_dst);
 
-        alpha = &alpha_f32;
-        beta  = &beta_f32;
-    }
+    nbd2 /= sizeof(float) / sizeof(sycl::half);
+    nbd3 /= sizeof(float) / sizeof(sycl::half);
 
     GGML_ASSERT(ne12 % ne02 == 0);
     GGML_ASSERT(ne13 % ne03 == 0);
@@ -13386,10 +13376,10 @@ static void ggml_sycl_mul_mat_mat_batched_sycl(const ggml_tensor *src0,
             *g_sycl_handles[g_main_device_index], oneapi::mkl::transpose::trans,
             oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,
             (const char *)src0_as_f16, dpct::library_data_t::real_half,
-            nb01 / sizeof(sycl::half), src0->nb[2] / sizeof(sycl::half),
-            (const char *)src1_as_f16.get(), dpct::library_data_t::real_half,
-            nb11 / sizeof(float), src1->nb[2] / sizeof(float), beta,
-            (char *)dst_t, cu_data_type, ne01, dst->nb[2] / sizeof(float),
+            nb01 / nb00, nb02 / nb00,
+            (const char *)src1_f16, dpct::library_data_t::real_half,
+            nb11 / nb10, nb12 / nb10, beta,
+            (char *)dst_t, cu_data_type, ne01, nb2 / nb0,
             ne12 * ne13, cu_compute_type)));
     } else {
         // use syclGemmBatchedEx
@@ -13409,44 +13399,35 @@ static void ggml_sycl_mul_mat_mat_batched_sycl(const ggml_tensor *src0,
                                          {sycl::aspect::fp16});
 
             main_stream->submit([&](sycl::handler &cgh) {
-                const sycl::half *src1_as_f16_get_ct1 = src1_as_f16.get();
-                const void **ptrs_src_get_ct3 = ptrs_src.get();
-                void **ptrs_dst_get_ct4 = ptrs_dst.get();
-
+                const void **ptrs_src_get = ptrs_src.get();
+                void **ptrs_dst_get = ptrs_dst.get();
+                size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : nb12 / 2;
+                size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : nb13 / 2;
                 cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims),
                                  [=](sycl::nd_item<3> item_ct1) {
                                      k_compute_batched_ptrs(
-                                         src0_as_f16, src1_as_f16_get_ct1,
-                                         dst_t, ptrs_src_get_ct3,
-                                         ptrs_dst_get_ct4, ne12, ne13, ne23,
-                                         nb02, nb03, nb12, nb13, nbd2, nbd3, r2,
-                                         r3, item_ct1);
+                                         src0_as_f16, src1_f16,
+                                         dst_t, ptrs_src_get,
+                                         ptrs_dst_get, ne12, ne13, ne23,
+                                         nb02, nb03, nb12_scaled, nb13_scaled,
+                                         nbd2, nbd3, r2, r3, item_ct1);
                                  });
             });
         }
-        /*
-        DPCT1010:95: SYCL uses exceptions to report errors and does not use the
-        error codes. The call was replaced with 0. You need to rewrite this
-        code.
-        */
-        SYCL_CHECK(0);
-
         SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
             *g_sycl_handles[g_main_device_index], oneapi::mkl::transpose::trans,
             oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,
             (const void **)(ptrs_src.get() + 0 * ne23),
-            dpct::library_data_t::real_half, nb01 / sizeof(sycl::half),
+            dpct::library_data_t::real_half, nb01 / nb00,
             (const void **)(ptrs_src.get() + 1 * ne23),
-            dpct::library_data_t::real_half, nb11 / sizeof(float), beta,
+            dpct::library_data_t::real_half, nb11 / nb10, beta,
             (void **)(ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23,
             cu_compute_type)));
     }
 #endif
 
-    if (dst->op_params[0] == GGML_PREC_DEFAULT) {
-        const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
-        to_fp32_sycl(dst_f16.get(), dst_ddf, ne, main_stream);
-    }
+    const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
+    to_fp32_sycl(dst_f16.get(), dst_ddf, ne_dst, main_stream);
 }
 catch (sycl::exception const &exc) {
   std::cerr << exc.what() << "Exception caught at file:" << __FILE__
@@ -13491,10 +13472,10 @@ static void ggml_sycl_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
         // KQV single-batch
         // GGML_SYCL_DEBUG("ggml_sycl_mul_mat_vec_nc\n");
         ggml_sycl_mul_mat_vec_nc(src0, src1, dst);
-    } else if (!split && all_on_device && use_xmx && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1)) {
+    } else if (!split && all_on_device && use_xmx && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1)) {
         // KQ + KQV multi-batch
-        // GGML_SYCL_DEBUG("ggml_sycl_mul_mat_mat_batched_sycl\n");
-        ggml_sycl_mul_mat_mat_batched_sycl(src0, src1, dst);
+        // GGML_SYCL_DEBUG("ggml_sycl_mul_mat_batched_sycl\n");
+        ggml_sycl_mul_mat_batched_sycl(src0, src1, dst);
     } else if (src0->type == GGML_TYPE_F32) {
         // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat\n");
         ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
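
A note on the stride arguments rewritten in both dpct::gemm_batch calls above: for non-quantized ggml tensors, nb[0] is the size in bytes of a single element, so nb01 / nb00 and nb11 / nb10 give leading dimensions in elements without hard-coding sizeof(sycl::half) or sizeof(float), which is what lets the same call serve an F32 or F16 src1. Below is a minimal sketch of that arithmetic under a simplified stand-in for ggml_tensor; toy_tensor and gemm_strides_in_elements are illustrative names, not ggml or oneMKL API.

    // editorial sketch, not part of the patch
    #include <cstdint>
    #include <cstdio>

    struct toy_tensor {
        int64_t ne[4]; // number of elements per dimension
        size_t  nb[4]; // byte stride per dimension; nb[0] == element size for F16/F32
    };

    // leading dimension and per-matrix batch stride, expressed in elements,
    // mirroring nb01 / nb00 and nb02 / nb00 in the diff above
    static void gemm_strides_in_elements(const toy_tensor & t, int64_t & ld, int64_t & stride) {
        ld     = (int64_t)(t.nb[1] / t.nb[0]);
        stride = (int64_t)(t.nb[2] / t.nb[0]);
    }

    int main() {
        // contiguous 4096 x 32 F16 matrix (2-byte elements), 8 matrices in the batch
        toy_tensor src0 = {{4096, 32, 8, 1}, {2, 2*4096, 2*4096*32, 2*4096*32*8}};
        int64_t ld = 0, stride = 0;
        gemm_strides_in_elements(src0, ld, stride);
        std::printf("ld = %lld, batch stride = %lld elements\n",
                    (long long) ld, (long long) stride);
        return 0;
    }

Because the division uses each tensor's own nb[0]/nb10, the result is in elements whether src1 arrives as F32 or already as F16, so the element-type assumptions baked into the old divisors are no longer needed.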