Commit 38d1521

[SYCL] Use batched mul_mat pathway (#5591)
* Use batched mul_mat pathway
* rm extra line
* Explicitly state scaled data type

---------

Co-authored-by: Abhilash Majumder <[email protected]>
1 parent 052051d commit 38d1521

File tree

1 file changed: +44 -63 lines changed


ggml-sycl.cpp

Lines changed: 44 additions & 63 deletions
@@ -12726,6 +12726,7 @@ static void ggml_sycl_op_mul_mat(const ggml_tensor *src0,
 
     GGML_ASSERT(dst->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
     GGML_ASSERT(src1->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32 || (src1->ne[2] == 1 && src1->ne[3] == 1));
 
     GGML_ASSERT(ne12 >= ne02 && ne12 % ne02 == 0);
 
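Note: the added assert restricts the generic ggml_sycl_op_mul_mat pathway: non-F32 src1 is only accepted when it carries no batch dimensions, since batched non-F32 input is now meant to take the batched pathway changed below. A minimal standalone sketch of that shape rule (the enum and helper are illustrative, not from the commit):

    #include <cstdint>
    #include <cstdio>

    enum Type { F32, F16 };  // stand-ins for GGML_TYPE_F32 / GGML_TYPE_F16

    // Shape rule encoded by the added GGML_ASSERT: non-F32 src1 must have
    // no batch dimensions (ne[2] == ne[3] == 1) to take this generic path.
    static bool generic_path_ok(Type t, int64_t ne2, int64_t ne3) {
        return t == F32 || (ne2 == 1 && ne3 == 1);
    }

    int main() {
        std::printf("%d\n", generic_path_ok(F32, 8, 4)); // 1: batched fp32 is fine
        std::printf("%d\n", generic_path_ok(F16, 1, 1)); // 1: unbatched fp16 is fine
        std::printf("%d\n", generic_path_ok(F16, 8, 4)); // 0: batched fp16 uses the batched pathway
        return 0;
    }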

@@ -13269,31 +13270,23 @@ static void k_compute_batched_ptrs(const sycl::half *src0_as_f16,
     int64_t i03 = i13 / r3;
     int64_t i02 = i12 / r2;
 
-    ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02 + i03*nb03;
-    ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12/2 + i13*nb13/2;
-    ptrs_dst[0*ne23 + i12 + i13*ne12] = (      char *)         dst + i12*nbd2 + i13*nbd3;
+    ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02 + i03*nb03;
+    ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12 + i13*nb13;
+    ptrs_dst[0*ne23 + i12 + i13*ne12] = (      char *)         dst + i12*nbd2 + i13*nbd3;
 }
 
-static void ggml_sycl_mul_mat_mat_batched_sycl(const ggml_tensor *src0,
-                                               const ggml_tensor *src1,
-                                               ggml_tensor *dst) try {
+static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0,
+                                           const ggml_tensor *src1,
+                                           ggml_tensor *dst) try {
     GGML_ASSERT(!ggml_is_transposed(src0));
     GGML_ASSERT(!ggml_is_transposed(src1));
 
     GGML_ASSERT(src0->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-
-    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne);
-
-    GGML_TENSOR_LOCALS(int64_t, nb0, src0, nb);
 
-    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne);
-
-    GGML_TENSOR_LOCALS(int64_t, nb1, src1, nb);
+    GGML_TENSOR_BINARY_OP_LOCALS
 
-    const int64_t ne1 = ggml_nelements(src1);
-    const int64_t ne = ggml_nelements(dst);
+    const int64_t ne_dst = ggml_nelements(dst);
 
     SYCL_CHECK(ggml_sycl_set_device(g_main_device));
     dpct::queue_ptr main_stream = g_syclStreams[g_main_device_index][0];
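Note: GGML_TENSOR_BINARY_OP_LOCALS collapses the hand-written GGML_TENSOR_LOCALS lines and additionally declares locals for dst, which is what lets the later hunks write nb2 / nb0 for the destination strides. Roughly its definition in ggml.h at the time (paraphrased; check ggml.h for the authoritative macro):

    #define GGML_TENSOR_BINARY_OP_LOCALS \
        GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) /* ne00..ne03 */ \
        GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb) /* nb00..nb03 */ \
        GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) /* ne10..ne13 */ \
        GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb) /* nb10..nb13 */ \
        GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne) /* ne0..ne3   */ \
        GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb) /* nb0..nb3   */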
@@ -13312,11 +13305,16 @@ static void ggml_sycl_mul_mat_mat_batched_sycl(const ggml_tensor *src0,
     float * dst_ddf = (float *) dst_extra->data_device[g_main_device_index];
 
     // convert src1 to fp16
-    const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type);
-    GGML_ASSERT(to_fp16_sycl != nullptr);
-
-    sycl_pool_alloc<sycl::half> src1_as_f16(ne1);
-    to_fp16_sycl(src1_ddf, src1_as_f16.get(), ne1, main_stream);
+    sycl_pool_alloc<sycl::half> src1_f16_alloc;
+    if (src1->type != GGML_TYPE_F16) {
+        const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type);
+        const int64_t ne_src1 = ggml_nelements(src1);
+        src1_f16_alloc.alloc(ne_src1);
+        GGML_ASSERT(to_fp16_sycl != nullptr);
+        to_fp16_sycl(src1_ddf, src1_f16_alloc.get(), ne_src1, main_stream);
+    }
+    sycl::half *src1_f16 = src1->type == GGML_TYPE_F16 ? (sycl::half *)src1_ddf
+                                                       : src1_f16_alloc.get();
 
     sycl_pool_alloc<sycl::half> dst_f16;
     char * dst_t;
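Note: the rewritten conversion block only materializes an fp16 copy of src1 when it is not already fp16; otherwise the existing device buffer is aliased directly. A minimal host-side sketch of that control flow (uint16_t stands in for sycl::half and the cast is a placeholder, not a real fp32-to-fp16 conversion):

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    using half_t = std::uint16_t;  // placeholder for sycl::half

    // Convert-only-when-needed: mirrors the commit's control flow, not its API.
    static const half_t * as_f16(bool already_f16, const void * data, std::size_t n,
                                 std::vector<half_t> & scratch) {
        if (already_f16) {
            return static_cast<const half_t *>(data);   // alias the buffer in place
        }
        scratch.resize(n);                              // sycl_pool_alloc in the real code
        const float * f = static_cast<const float *>(data);
        for (std::size_t i = 0; i < n; ++i) {
            scratch[i] = static_cast<half_t>(f[i]);     // placeholder for to_fp16_sycl
        }
        return scratch.data();
    }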
@@ -13337,20 +13335,12 @@ static void ggml_sycl_mul_mat_mat_batched_sycl(const ggml_tensor *src0,
     const void * alpha = &alpha_f16;
     const void * beta = &beta_f16;
 
-    if (dst->op_params[0] == GGML_PREC_DEFAULT) {
-        dst_t = (char *) dst_f16.alloc(ne);
-
-        nbd2 /= sizeof(float) / sizeof(sycl::half);
-        nbd3 /= sizeof(float) / sizeof(sycl::half);
-    } else {
-        dst_t = (char *) dst_ddf;
-
-        cu_compute_type = dpct::library_data_t::real_float;
-        cu_data_type = dpct::library_data_t::real_float;
+    // TODO: Re-enable the (dst->op_params[0] != GGML_PREC_DEFAULT) pathway
+    // once oneMKL open source supports half, half, float, float: datatypes
+    dst_t = (char *) dst_f16.alloc(ne_dst);
 
-        alpha = &alpha_f32;
-        beta = &beta_f32;
-    }
+    nbd2 /= sizeof(float) / sizeof(sycl::half);
+    nbd3 /= sizeof(float) / sizeof(sycl::half);
 
     GGML_ASSERT(ne12 % ne02 == 0);
     GGML_ASSERT(ne13 % ne03 == 0);
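Note: with the float pathway disabled, the GEMM always writes fp16, so the destination batch strides taken from the fp32 dst tensor must shrink by sizeof(float) / sizeof(sycl::half) == 2. A worked standalone example of that rescale (uint16_t again stands in for sycl::half; the extents are illustrative):

    #include <cstddef>
    #include <cstdint>
    #include <cstdio>

    int main() {
        const std::size_t ne0 = 128, ne1 = 32;                // illustrative extents
        std::size_t nbd2 = ne0 * ne1 * sizeof(float);         // fp32 batch stride: 16384 bytes
        nbd2 /= sizeof(float) / sizeof(std::uint16_t);        // divide by 4/2 == 2
        std::printf("fp16 batch stride: %zu bytes\n", nbd2);  // prints 8192
        return 0;
    }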
@@ -13386,10 +13376,10 @@ static void ggml_sycl_mul_mat_mat_batched_sycl(const ggml_tensor *src0,
         *g_sycl_handles[g_main_device_index], oneapi::mkl::transpose::trans,
         oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,
         (const char *)src0_as_f16, dpct::library_data_t::real_half,
-        nb01 / sizeof(sycl::half), src0->nb[2] / sizeof(sycl::half),
-        (const char *)src1_as_f16.get(), dpct::library_data_t::real_half,
-        nb11 / sizeof(float), src1->nb[2] / sizeof(float), beta,
-        (char *)dst_t, cu_data_type, ne01, dst->nb[2] / sizeof(float),
+        nb01 / nb00, nb02 / nb00,
+        (const char *)src1_f16, dpct::library_data_t::real_half,
+        nb11 / nb10, nb12 / nb10, beta,
+        (char *)dst_t, cu_data_type, ne01, nb2 / nb0,
         ne12 * ne13, cu_compute_type)));
     } else {
         // use syclGemmBatchedEx
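Note: dividing a byte stride by the tensor's own element stride (nb01 / nb00, nb11 / nb10, nb2 / nb0) yields the leading dimension in elements whatever the element type; the old hard-coded sizeof(float) divisor was only valid while src1 was guaranteed to be fp32. A standalone arithmetic sketch of the difference for an fp16 src1 with 64-element rows:

    #include <cstddef>
    #include <cstdio>

    int main() {
        const std::size_t K = 64;
        const std::size_t nb10 = 2;         // fp16 element stride in bytes
        const std::size_t nb11 = nb10 * K;  // contiguous row stride: 128 bytes
        std::printf("ldb (type-agnostic) = %zu\n", nb11 / nb10);          // 64
        std::printf("ldb (old divisor)   = %zu\n", nb11 / sizeof(float)); // 32 -- wrong for fp16 input
        return 0;
    }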
@@ -13409,44 +13399,35 @@ static void ggml_sycl_mul_mat_mat_batched_sycl(const ggml_tensor *src0,
                                          {sycl::aspect::fp16});
 
             main_stream->submit([&](sycl::handler &cgh) {
-                const sycl::half *src1_as_f16_get_ct1 = src1_as_f16.get();
-                const void **ptrs_src_get_ct3 = ptrs_src.get();
-                void **ptrs_dst_get_ct4 = ptrs_dst.get();
-
+                const void **ptrs_src_get = ptrs_src.get();
+                void **ptrs_dst_get = ptrs_dst.get();
+                size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : nb12 / 2;
+                size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : nb13 / 2;
                 cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims),
                                  [=](sycl::nd_item<3> item_ct1) {
                                      k_compute_batched_ptrs(
-                                         src0_as_f16, src1_as_f16_get_ct1,
-                                         dst_t, ptrs_src_get_ct3,
-                                         ptrs_dst_get_ct4, ne12, ne13, ne23,
-                                         nb02, nb03, nb12, nb13, nbd2, nbd3, r2,
-                                         r3, item_ct1);
+                                         src0_as_f16, src1_f16,
+                                         dst_t, ptrs_src_get,
+                                         ptrs_dst_get, ne12, ne13, ne23,
+                                         nb02, nb03, nb12_scaled, nb13_scaled,
+                                         nbd2, nbd3, r2, r3, item_ct1);
                                  });
             });
         }
-        /*
-        DPCT1010:95: SYCL uses exceptions to report errors and does not use the
-        error codes. The call was replaced with 0. You need to rewrite this
-        code.
-        */
-        SYCL_CHECK(0);
-
         SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
             *g_sycl_handles[g_main_device_index], oneapi::mkl::transpose::trans,
             oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,
             (const void **)(ptrs_src.get() + 0 * ne23),
-            dpct::library_data_t::real_half, nb01 / sizeof(sycl::half),
+            dpct::library_data_t::real_half, nb01 / nb00,
            (const void **)(ptrs_src.get() + 1 * ne23),
-            dpct::library_data_t::real_half, nb11 / sizeof(float), beta,
+            dpct::library_data_t::real_half, nb11 / nb10, beta,
             (void **)(ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23,
             cu_compute_type)));
     }
 #endif
 
-    if (dst->op_params[0] == GGML_PREC_DEFAULT) {
-        const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
-        to_fp32_sycl(dst_f16.get(), dst_ddf, ne, main_stream);
-    }
+    const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
+    to_fp32_sycl(dst_f16.get(), dst_ddf, ne_dst, main_stream);
 }
 catch (sycl::exception const &exc) {
     std::cerr << exc.what() << "Exception caught at file:" << __FILE__
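Note: the nb12_scaled / nb13_scaled terms account for the converted buffer: when src1 arrived as fp32, its fp16 copy is contiguous with byte strides exactly half the originals, while already-fp16 input keeps its strides unchanged. A serial host-side sketch of the pointer-table computation the kernel performs over its nd-range (same indexing as k_compute_batched_ptrs, minus the SYCL launch):

    #include <cstddef>
    #include <cstdint>

    // One base pointer per (i12, i13) batch entry for dpct::gemm_batch;
    // src0 is broadcast across src1's batch dims via r2 = ne12/ne02, r3 = ne13/ne03.
    static void compute_batched_ptrs(const void * src0_f16, const void * src1_f16, void * dst,
                                     const void ** ptrs_src, void ** ptrs_dst,
                                     int64_t ne12, int64_t ne13, int64_t ne23,
                                     std::size_t nb02, std::size_t nb03,
                                     std::size_t nb12, std::size_t nb13,
                                     std::size_t nbd2, std::size_t nbd3,
                                     int64_t r2, int64_t r3) {
        for (int64_t i13 = 0; i13 < ne13; ++i13) {
            for (int64_t i12 = 0; i12 < ne12; ++i12) {
                const int64_t i03 = i13 / r3;   // which src0 batch serves this src1 batch
                const int64_t i02 = i12 / r2;
                ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_f16 + i02*nb02 + i03*nb03;
                ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_f16 + i12*nb12 + i13*nb13;
                ptrs_dst[0*ne23 + i12 + i13*ne12] = (      char *) dst      + i12*nbd2 + i13*nbd3;
            }
        }
    }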
@@ -13491,10 +13472,10 @@ static void ggml_sycl_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
         // KQV single-batch
         // GGML_SYCL_DEBUG("ggml_sycl_mul_mat_vec_nc\n");
         ggml_sycl_mul_mat_vec_nc(src0, src1, dst);
-    } else if (!split && all_on_device && use_xmx && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1)) {
+    } else if (!split && all_on_device && use_xmx && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1)) {
         // KQ + KQV multi-batch
-        // GGML_SYCL_DEBUG("ggml_sycl_mul_mat_mat_batched_sycl\n");
-        ggml_sycl_mul_mat_mat_batched_sycl(src0, src1, dst);
+        // GGML_SYCL_DEBUG("ggml_sycl_mul_mat_batched_sycl\n");
+        ggml_sycl_mul_mat_batched_sycl(src0, src1, dst);
     } else if (src0->type == GGML_TYPE_F32) {
         // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat\n");
         ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
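Note: this dispatch change is the one-line relaxation that makes the rest of the patch reachable: dropping the src1->type == GGML_TYPE_F32 requirement routes fp16 src1 (the KQ + KQV multi-batch cases) into the batched pathway. A boolean sketch of the relaxed predicate (the parameters stand in for the real tensor checks):

    #include <cstdio>

    static bool take_batched_path(bool split, bool all_on_device, bool use_xmx,
                                  bool src0_is_f16, bool src0_transposed, bool src1_transposed) {
        // src1's type no longer appears in the condition.
        return !split && all_on_device && use_xmx && src0_is_f16
            && !src0_transposed && !src1_transposed;
    }

    int main() {
        // fp16 src1 no longer disqualifies the batched path:
        std::printf("%d\n", take_batched_path(false, true, true, true, false, false)); // 1
        return 0;
    }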
