@@ -1699,9 +1699,14 @@ namespace dpct
             auto data_a = get_memory<const Ta>(a);
             auto data_b = get_memory<const Tb>(b);
             auto data_c = get_memory<Tc>(c);
-            oneapi::mkl::blas::column_major::gemm(
-                q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda,
-                data_b, ldb, beta_value, data_c, ldc);
+#ifdef GGML_SYCL_NVIDIA
+            oneapi::mkl::blas::column_major::gemm(oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q },
+                                                  a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb,
+                                                  beta_value, data_c, ldc);
+#else
+            oneapi::mkl::blas::column_major::gemm(q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb,
+                                                  beta_value, data_c, ldc);
+#endif
         }
 
         template <typename VecT, class BinaryOperation, class = void>
@@ -1764,14 +1769,22 @@ namespace dpct
             matrix_info->ld_info[2] = ldc;
             matrix_info->groupsize_info = batch_size;
 
+#ifdef GGML_SYCL_NVIDIA
+            sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
+                oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, matrix_info->transpose_info,
+                matrix_info->transpose_info + 1, matrix_info->size_info, matrix_info->size_info + 1,
+                matrix_info->size_info + 2, matrix_info->value_info, reinterpret_cast<const Ta **>(a),
+                matrix_info->ld_info, reinterpret_cast<const Tb **>(b), matrix_info->ld_info + 1,
+                matrix_info->value_info + 1, reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1,
+                &(matrix_info->groupsize_info));
+#else
             sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
-                q, matrix_info->transpose_info, matrix_info->transpose_info + 1,
-                matrix_info->size_info, matrix_info->size_info + 1,
-                matrix_info->size_info + 2, matrix_info->value_info,
-                reinterpret_cast<const Ta **>(a), matrix_info->ld_info,
-                reinterpret_cast<const Tb **>(b), matrix_info->ld_info + 1,
-                matrix_info->value_info + 1, reinterpret_cast<Tc **>(c),
+                q, matrix_info->transpose_info, matrix_info->transpose_info + 1, matrix_info->size_info,
+                matrix_info->size_info + 1, matrix_info->size_info + 2, matrix_info->value_info,
+                reinterpret_cast<const Ta **>(a), matrix_info->ld_info, reinterpret_cast<const Tb **>(b),
+                matrix_info->ld_info + 1, matrix_info->value_info + 1, reinterpret_cast<Tc **>(c),
                 matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
+#endif
 
             q.submit([&](sycl::handler &cgh)
             {
@@ -1793,10 +1806,16 @@ namespace dpct
             auto data_a = get_memory<const Ta>(a);
             auto data_b = get_memory<const Tb>(b);
             auto data_c = get_memory<Tc>(c);
+#ifdef GGML_SYCL_NVIDIA
             oneapi::mkl::blas::column_major::gemm_batch(
-                q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda,
-                stride_a, data_b, ldb, stride_b, beta_value,
-                data_c, ldc, stride_c, batch_size);
+                oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, a_trans, b_trans, m, n, k,
+                alpha_value, data_a, lda, stride_a, data_b, ldb, stride_b, beta_value, data_c, ldc, stride_c,
+                batch_size);
+#else
+            oneapi::mkl::blas::column_major::gemm_batch(q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda,
+                                                        stride_a, data_b, ldb, stride_b, beta_value, data_c, ldc,
+                                                        stride_c, batch_size);
+#endif
         }
 
     } // namespace detail
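
Note (not part of the patch): when GGML_SYCL_NVIDIA is defined, these hunks replace runtime backend dispatch with a compile-time oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>, so the oneMKL interface GEMM calls bind directly to the cuBLAS backend instead of resolving a backend from the queue's device. Below is a minimal standalone sketch of the two call forms, assuming the open-source oneMKL interface library built with its cuBLAS backend; the gemm_example wrapper, its argument names, and the leading dimensions are illustrative only.

// Illustrative sketch only; not taken from the patch above.
#include <cstdint>
#include <sycl/sycl.hpp>
#include <oneapi/mkl.hpp>

// a, b, c are assumed to be USM device pointers accessible from q.
void gemm_example(sycl::queue &q, const float *a, const float *b, float *c,
                  std::int64_t m, std::int64_t n, std::int64_t k) {
    const float alpha = 1.0f, beta = 0.0f;
#ifdef GGML_SYCL_NVIDIA
    // Compile-time selection: bind this call directly to the cuBLAS backend.
    oneapi::mkl::blas::column_major::gemm(
        oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q },
        oneapi::mkl::transpose::nontrans, oneapi::mkl::transpose::nontrans,
        m, n, k, alpha, a, m, b, k, beta, c, m);
#else
    // Runtime dispatch: the library picks a backend from the queue's device.
    oneapi::mkl::blas::column_major::gemm(
        q, oneapi::mkl::transpose::nontrans, oneapi::mkl::transpose::nontrans,
        m, n, k, alpha, a, m, b, k, beta, c, m);
#endif
}

Both forms take the same BLAS arguments; only the first parameter (queue vs. backend selector wrapping the queue) differs.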