Skip to content

Commit db892e1

Browse files
s-Nickarthw
authored andcommitted
SYCL : Move to compile time oneMKL interface backend selection for NVIDIA backend (ggml-org#10584)
* [SYCL] Move to Compile Time backend selection on oneMKL Interface for NVIDIA backend Move to compile time selection to backend to avoid latency at run time. Add it to all mkl gemm calls and only for NVIDIA backend. Signed-off-by: nscipione <[email protected]> * Formatting * Address PR comments to increase readibility --------- Signed-off-by: nscipione <[email protected]>
1 parent 7fd12a6 commit db892e1

File tree

3 files changed

+42
-17
lines changed

3 files changed

+42
-17
lines changed

ggml/src/ggml-sycl/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,8 @@ else()
6868
target_link_libraries(ggml-sycl PRIVATE sycl OpenCL mkl_core pthread m dl mkl_sycl_blas mkl_intel_ilp64 mkl_tbb_thread)
6969
elseif (GGML_SYCL_TARGET STREQUAL "NVIDIA")
7070
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl-targets=nvptx64-nvidia-cuda")
71-
target_link_libraries(ggml-sycl PRIVATE sycl pthread m dl onemkl)
71+
add_compile_definitions(GGML_SYCL_NVIDIA)
72+
target_link_libraries(ggml-sycl PRIVATE sycl pthread m dl onemkl_blas_cublas)
7273
elseif (GGML_SYCL_TARGET STREQUAL "AMD")
7374
if (NOT GGML_SYCL_DEVICE_ARCH)
7475
message(ERROR "Can't enable SYCL hip backend, GGML_SYCL_DEVICE_ARCH has not been set.")

ggml/src/ggml-sycl/dpct.hpp

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1699,9 +1699,14 @@ namespace dpct
16991699
auto data_a = get_memory<const Ta>(a);
17001700
auto data_b = get_memory<const Tb>(b);
17011701
auto data_c = get_memory<Tc>(c);
1702-
oneapi::mkl::blas::column_major::gemm(
1703-
q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda,
1704-
data_b, ldb, beta_value, data_c, ldc);
1702+
#ifdef GGML_SYCL_NVIDIA
1703+
oneapi::mkl::blas::column_major::gemm(oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q },
1704+
a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb,
1705+
beta_value, data_c, ldc);
1706+
#else
1707+
oneapi::mkl::blas::column_major::gemm(q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb,
1708+
beta_value, data_c, ldc);
1709+
#endif
17051710
}
17061711

17071712
template <typename VecT, class BinaryOperation, class = void>
@@ -1764,14 +1769,22 @@ namespace dpct
17641769
matrix_info->ld_info[2] = ldc;
17651770
matrix_info->groupsize_info = batch_size;
17661771

1772+
#ifdef GGML_SYCL_NVIDIA
1773+
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
1774+
oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, matrix_info->transpose_info,
1775+
matrix_info->transpose_info + 1, matrix_info->size_info, matrix_info->size_info + 1,
1776+
matrix_info->size_info + 2, matrix_info->value_info, reinterpret_cast<const Ta **>(a),
1777+
matrix_info->ld_info, reinterpret_cast<const Tb **>(b), matrix_info->ld_info + 1,
1778+
matrix_info->value_info + 1, reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1,
1779+
&(matrix_info->groupsize_info));
1780+
#else
17671781
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
1768-
q, matrix_info->transpose_info, matrix_info->transpose_info + 1,
1769-
matrix_info->size_info, matrix_info->size_info + 1,
1770-
matrix_info->size_info + 2, matrix_info->value_info,
1771-
reinterpret_cast<const Ta **>(a), matrix_info->ld_info,
1772-
reinterpret_cast<const Tb **>(b), matrix_info->ld_info + 1,
1773-
matrix_info->value_info + 1, reinterpret_cast<Tc **>(c),
1782+
q, matrix_info->transpose_info, matrix_info->transpose_info + 1, matrix_info->size_info,
1783+
matrix_info->size_info + 1, matrix_info->size_info + 2, matrix_info->value_info,
1784+
reinterpret_cast<const Ta **>(a), matrix_info->ld_info, reinterpret_cast<const Tb **>(b),
1785+
matrix_info->ld_info + 1, matrix_info->value_info + 1, reinterpret_cast<Tc **>(c),
17741786
matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
1787+
#endif
17751788

17761789
q.submit([&](sycl::handler &cgh)
17771790
{
@@ -1793,10 +1806,16 @@ namespace dpct
17931806
auto data_a = get_memory<const Ta>(a);
17941807
auto data_b = get_memory<const Tb>(b);
17951808
auto data_c = get_memory<Tc>(c);
1809+
#ifdef GGML_SYCL_NVIDIA
17961810
oneapi::mkl::blas::column_major::gemm_batch(
1797-
q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda,
1798-
stride_a, data_b, ldb, stride_b, beta_value,
1799-
data_c, ldc, stride_c, batch_size);
1811+
oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, a_trans, b_trans, m, n, k,
1812+
alpha_value, data_a, lda, stride_a, data_b, ldb, stride_b, beta_value, data_c, ldc, stride_c,
1813+
batch_size);
1814+
#else
1815+
oneapi::mkl::blas::column_major::gemm_batch(q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda,
1816+
stride_a, data_b, ldb, stride_b, beta_value, data_c, ldc,
1817+
stride_c, batch_size);
1818+
#endif
18001819
}
18011820

18021821
} // namespace detail

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2434,12 +2434,17 @@ inline void ggml_sycl_op_mul_mat_sycl(
24342434
const float alpha = 1.0f;
24352435
const float beta = 0.0f;
24362436
#if !GGML_SYCL_DNNL
2437+
# ifdef GGML_SYCL_NVIDIA
24372438
SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm(
2438-
*stream, oneapi::mkl::transpose::trans,
2439-
oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
2440-
dpct::get_value(&alpha, *stream), src0_ddf_i, ne00,
2441-
src1_ddf1_i, ne10, dpct::get_value(&beta, *stream),
2439+
oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ *stream }, oneapi::mkl::transpose::trans,
2440+
oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i,
2441+
ne00, src1_ddf1_i, ne10, dpct::get_value(&beta, *stream), dst_dd_i, ldc)));
2442+
# else
2443+
SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm(
2444+
*stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
2445+
dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10, dpct::get_value(&beta, *stream),
24422446
dst_dd_i, ldc)));
2447+
# endif
24432448
#else
24442449
auto dnnl_stream = ctx.stream_dnnl(stream);
24452450
DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i, DnnlGemmWrapper::to_dt<float>(),

0 commit comments

Comments
 (0)