@@ -1689,9 +1689,14 @@ namespace dpct
1689
1689
auto data_a = get_memory<const Ta>(a);
1690
1690
auto data_b = get_memory<const Tb>(b);
1691
1691
auto data_c = get_memory<Tc>(c);
1692
- oneapi::mkl::blas::column_major::gemm (
1693
- q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda,
1694
- data_b, ldb, beta_value, data_c, ldc);
1692
+ #ifdef GGML_SYCL_NVIDIA
1693
+ oneapi::mkl::blas::column_major::gemm (oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q },
1694
+ a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb,
1695
+ beta_value, data_c, ldc);
1696
+ #else
1697
+ oneapi::mkl::blas::column_major::gemm (q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb,
1698
+ beta_value, data_c, ldc);
1699
+ #endif
1695
1700
}
1696
1701
1697
1702
template <typename VecT, class BinaryOperation , class = void >
@@ -1754,14 +1759,22 @@ namespace dpct
1754
1759
matrix_info->ld_info [2 ] = ldc;
1755
1760
matrix_info->groupsize_info = batch_size;
1756
1761
1762
+ #ifdef GGML_SYCL_NVIDIA
1763
+ sycl::event e = oneapi::mkl::blas::column_major::gemm_batch (
1764
+ oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, matrix_info->transpose_info ,
1765
+ matrix_info->transpose_info + 1 , matrix_info->size_info , matrix_info->size_info + 1 ,
1766
+ matrix_info->size_info + 2 , matrix_info->value_info , reinterpret_cast <const Ta **>(a),
1767
+ matrix_info->ld_info , reinterpret_cast <const Tb **>(b), matrix_info->ld_info + 1 ,
1768
+ matrix_info->value_info + 1 , reinterpret_cast <Tc **>(c), matrix_info->ld_info + 2 , 1 ,
1769
+ &(matrix_info->groupsize_info ));
1770
+ #else
1757
1771
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch (
1758
- q, matrix_info->transpose_info , matrix_info->transpose_info + 1 ,
1759
- matrix_info->size_info , matrix_info->size_info + 1 ,
1760
- matrix_info->size_info + 2 , matrix_info->value_info ,
1761
- reinterpret_cast <const Ta **>(a), matrix_info->ld_info ,
1762
- reinterpret_cast <const Tb **>(b), matrix_info->ld_info + 1 ,
1763
- matrix_info->value_info + 1 , reinterpret_cast <Tc **>(c),
1772
+ q, matrix_info->transpose_info , matrix_info->transpose_info + 1 , matrix_info->size_info ,
1773
+ matrix_info->size_info + 1 , matrix_info->size_info + 2 , matrix_info->value_info ,
1774
+ reinterpret_cast <const Ta **>(a), matrix_info->ld_info , reinterpret_cast <const Tb **>(b),
1775
+ matrix_info->ld_info + 1 , matrix_info->value_info + 1 , reinterpret_cast <Tc **>(c),
1764
1776
matrix_info->ld_info + 2 , 1 , &(matrix_info->groupsize_info ));
1777
+ #endif
1765
1778
1766
1779
q.submit ([&](sycl::handler &cgh)
1767
1780
{
@@ -1783,10 +1796,16 @@ namespace dpct
1783
1796
auto data_a = get_memory<const Ta>(a);
1784
1797
auto data_b = get_memory<const Tb>(b);
1785
1798
auto data_c = get_memory<Tc>(c);
1799
+ #ifdef GGML_SYCL_NVIDIA
1786
1800
oneapi::mkl::blas::column_major::gemm_batch (
1787
- q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda,
1788
- stride_a, data_b, ldb, stride_b, beta_value,
1789
- data_c, ldc, stride_c, batch_size);
1801
+ oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, a_trans, b_trans, m, n, k,
1802
+ alpha_value, data_a, lda, stride_a, data_b, ldb, stride_b, beta_value, data_c, ldc, stride_c,
1803
+ batch_size);
1804
+ #else
1805
+ oneapi::mkl::blas::column_major::gemm_batch (q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda,
1806
+ stride_a, data_b, ldb, stride_b, beta_value, data_c, ldc,
1807
+ stride_c, batch_size);
1808
+ #endif
1790
1809
}
1791
1810
1792
1811
} // namespace detail
0 commit comments