Skip to content

Commit 40c6d79

Browse files
authored
SYCL : Move to compile time oneMKL interface backend selection for NVIDIA backend (#10584)
* [SYCL] Move to compile-time backend selection on the oneMKL Interface for the NVIDIA backend. Move to compile-time selection of the backend to avoid latency at run time. Add it to all MKL gemm calls, and only for the NVIDIA backend. Signed-off-by: nscipione <[email protected]> * Formatting * Address PR comments to increase readability --------- Signed-off-by: nscipione <[email protected]>
1 parent 98036d5 commit 40c6d79

File tree

4 files changed

+50
-25
lines changed

4 files changed

+50
-25
lines changed

ggml/src/ggml-sycl/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,8 @@ else()
6868
target_link_libraries(ggml-sycl PRIVATE sycl OpenCL mkl_core pthread m dl mkl_sycl_blas mkl_intel_ilp64 mkl_tbb_thread)
6969
elseif (GGML_SYCL_TARGET STREQUAL "NVIDIA")
7070
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl-targets=nvptx64-nvidia-cuda")
71-
target_link_libraries(ggml-sycl PRIVATE sycl pthread m dl onemkl)
71+
add_compile_definitions(GGML_SYCL_NVIDIA)
72+
target_link_libraries(ggml-sycl PRIVATE sycl pthread m dl onemkl_blas_cublas)
7273
elseif (GGML_SYCL_TARGET STREQUAL "AMD")
7374
if (NOT GGML_SYCL_DEVICE_ARCH)
7475
message(ERROR "Can't enable SYCL hip backend, GGML_SYCL_DEVICE_ARCH has not been set.")

ggml/src/ggml-sycl/dpct/helper.hpp

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1689,9 +1689,14 @@ namespace dpct
16891689
auto data_a = get_memory<const Ta>(a);
16901690
auto data_b = get_memory<const Tb>(b);
16911691
auto data_c = get_memory<Tc>(c);
1692-
oneapi::mkl::blas::column_major::gemm(
1693-
q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda,
1694-
data_b, ldb, beta_value, data_c, ldc);
1692+
#ifdef GGML_SYCL_NVIDIA
1693+
oneapi::mkl::blas::column_major::gemm(oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q },
1694+
a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb,
1695+
beta_value, data_c, ldc);
1696+
#else
1697+
oneapi::mkl::blas::column_major::gemm(q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb,
1698+
beta_value, data_c, ldc);
1699+
#endif
16951700
}
16961701

16971702
template <typename VecT, class BinaryOperation, class = void>
@@ -1754,14 +1759,22 @@ namespace dpct
17541759
matrix_info->ld_info[2] = ldc;
17551760
matrix_info->groupsize_info = batch_size;
17561761

1762+
#ifdef GGML_SYCL_NVIDIA
1763+
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
1764+
oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, matrix_info->transpose_info,
1765+
matrix_info->transpose_info + 1, matrix_info->size_info, matrix_info->size_info + 1,
1766+
matrix_info->size_info + 2, matrix_info->value_info, reinterpret_cast<const Ta **>(a),
1767+
matrix_info->ld_info, reinterpret_cast<const Tb **>(b), matrix_info->ld_info + 1,
1768+
matrix_info->value_info + 1, reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1,
1769+
&(matrix_info->groupsize_info));
1770+
#else
17571771
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
1758-
q, matrix_info->transpose_info, matrix_info->transpose_info + 1,
1759-
matrix_info->size_info, matrix_info->size_info + 1,
1760-
matrix_info->size_info + 2, matrix_info->value_info,
1761-
reinterpret_cast<const Ta **>(a), matrix_info->ld_info,
1762-
reinterpret_cast<const Tb **>(b), matrix_info->ld_info + 1,
1763-
matrix_info->value_info + 1, reinterpret_cast<Tc **>(c),
1772+
q, matrix_info->transpose_info, matrix_info->transpose_info + 1, matrix_info->size_info,
1773+
matrix_info->size_info + 1, matrix_info->size_info + 2, matrix_info->value_info,
1774+
reinterpret_cast<const Ta **>(a), matrix_info->ld_info, reinterpret_cast<const Tb **>(b),
1775+
matrix_info->ld_info + 1, matrix_info->value_info + 1, reinterpret_cast<Tc **>(c),
17641776
matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
1777+
#endif
17651778

17661779
q.submit([&](sycl::handler &cgh)
17671780
{
@@ -1783,10 +1796,16 @@ namespace dpct
17831796
auto data_a = get_memory<const Ta>(a);
17841797
auto data_b = get_memory<const Tb>(b);
17851798
auto data_c = get_memory<Tc>(c);
1799+
#ifdef GGML_SYCL_NVIDIA
17861800
oneapi::mkl::blas::column_major::gemm_batch(
1787-
q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda,
1788-
stride_a, data_b, ldb, stride_b, beta_value,
1789-
data_c, ldc, stride_c, batch_size);
1801+
oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, a_trans, b_trans, m, n, k,
1802+
alpha_value, data_a, lda, stride_a, data_b, ldb, stride_b, beta_value, data_c, ldc, stride_c,
1803+
batch_size);
1804+
#else
1805+
oneapi::mkl::blas::column_major::gemm_batch(q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda,
1806+
stride_a, data_b, ldb, stride_b, beta_value, data_c, ldc,
1807+
stride_c, batch_size);
1808+
#endif
17901809
}
17911810

17921811
} // namespace detail

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2573,12 +2573,17 @@ inline void ggml_sycl_op_mul_mat_sycl(
25732573
const float alpha = 1.0f;
25742574
const float beta = 0.0f;
25752575
#if !GGML_SYCL_DNNL
2576+
# ifdef GGML_SYCL_NVIDIA
25762577
SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm(
2577-
*stream, oneapi::mkl::transpose::trans,
2578-
oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
2579-
dpct::get_value(&alpha, *stream), src0_ddf_i, ne00,
2580-
src1_ddf1_i, ne10, dpct::get_value(&beta, *stream),
2578+
oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ *stream }, oneapi::mkl::transpose::trans,
2579+
oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i,
2580+
ne00, src1_ddf1_i, ne10, dpct::get_value(&beta, *stream), dst_dd_i, ldc)));
2581+
# else
2582+
SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm(
2583+
*stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
2584+
dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10, dpct::get_value(&beta, *stream),
25812585
dst_dd_i, ldc)));
2586+
# endif
25822587
#else
25832588
auto dnnl_stream = ctx.stream_dnnl(stream);
25842589
DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i, DnnlGemmWrapper::to_dt<float>(),

ggml/src/ggml-sycl/outprod.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -40,14 +40,14 @@ void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, const ggml_tensor* sr
4040

4141
try {
4242
// Perform matrix multiplication using oneMKL GEMM
43-
oneapi::mkl::blas::column_major::gemm(*stream,
44-
oneapi::mkl::transpose::nontrans, src1_op,
45-
ne0, ne1, ne01,
46-
alpha,
47-
src0_d, ne00,
48-
src1_d, ldb,
49-
beta,
50-
dst_d, ne0);
43+
#ifdef GGML_SYCL_NVIDIA
44+
oneapi::mkl::blas::column_major::gemm(oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ *stream },
45+
oneapi::mkl::transpose::nontrans, src1_op, ne0, ne1, ne01, alpha, src0_d,
46+
ne00, src1_d, ldb, beta, dst_d, ne0);
47+
#else
48+
oneapi::mkl::blas::column_major::gemm(*stream, oneapi::mkl::transpose::nontrans, src1_op, ne0, ne1, ne01, alpha,
49+
src0_d, ne00, src1_d, ldb, beta, dst_d, ne0);
50+
#endif
5151
}
5252
catch (sycl::exception const& exc) {
5353
std::cerr << exc.what() << std::endl;

0 commit comments

Comments
 (0)