Skip to content

Commit 59acf14

Browse files
committed
Implement a workaround to gemm issue in OneMKL
1 parent 90d67f5 commit 59acf14

File tree

4 files changed

+57
-0
lines changed

4 files changed

+57
-0
lines changed

dpnp/backend/extensions/blas/blas_py.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,13 @@ PYBIND11_MODULE(_blas_impl, m)
127127
py::arg("resultC"), py::arg("depends") = py::list());
128128
}
129129

130+
{
131+
m.def("_is_lnl_arl_architecture", &blas_ns::_is_lnl_arl_architecture,
132+
"Return ``True`` if SYCL device belongs to either Lunar Lake or "
133+
"Arrow Lake architecture",
134+
py::arg("device"));
135+
}
136+
130137
{
131138
m.def("_gemm_batch", &blas_ns::gemm_batch,
132139
"Call `gemm_batch` from OneMKL BLAS library to compute "

dpnp/backend/extensions/blas/gemm.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,25 @@ std::tuple<sycl::event, sycl::event, bool>
323323
return std::make_tuple(args_ev, gemm_ev, is_row_major);
324324
}
325325

326+
// Check whether the given SYCL device is one of the Intel GPU
// architectures affected by the known OneMKL gemm issue (MKLD-17976):
// Lunar Lake (intel_gpu_20_4_4) or Arrow Lake (intel_gpu_12_74_4).
// Always reports `false` when building against cuBLAS, where the
// Intel architecture query does not apply.
bool _is_lnl_arl_architecture(sycl::device &dev)
{
#if !defined(USE_ONEMKL_CUBLAS)
    namespace syclex = sycl::ext::oneapi::experimental;

    return dev.ext_oneapi_architecture_is(
               syclex::architecture::intel_gpu_20_4_4) /* Lunar Lake */ ||
           dev.ext_oneapi_architecture_is(
               syclex::architecture::intel_gpu_12_74_4) /* Arrow Lake */;
#else
    return false;
#endif // !defined(USE_ONEMKL_CUBLAS)
}
344+
326345
template <typename fnT, typename Tab, typename Tc>
327346
struct GemmContigFactory
328347
{

dpnp/backend/extensions/blas/gemm.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ extern std::tuple<sycl::event, sycl::event, bool>
3939
const dpctl::tensor::usm_ndarray &resultC,
4040
const std::vector<sycl::event> &depends);
4141

42+
extern bool _is_lnl_arl_architecture(sycl::device &dev);
43+
4244
extern std::tuple<sycl::event, sycl::event, bool>
4345
gemm_batch(sycl::queue &exec_q,
4446
const dpctl::tensor::usm_ndarray &matrixA,

dpnp/dpnp_utils/dpnp_utils_linearalgebra.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -894,6 +894,35 @@ def dpnp_matmul(
894894
)
895895
_manager.add_event_pair(ht_ev, gemv_ev)
896896
elif call_flag == "gemm":
897+
# MKLD-17976: due to known issue in OneMKL on Lunar Lake and
898+
# Arrow Lake architectures, we are forced to implement a temporary
899+
# workaround with extra copying of an input array in case when
900+
# it has a small size and non-zero offset
901+
# TODO: remove the workaround once the OneMKL issue is resolved
902+
if (
903+
compute_dtype != dpnp.float32
904+
and bi._is_lnl_arl_architecture(exec_q.get_sycl_device())
905+
):
906+
907+
def _need_to_copy(a):
    # A defensive copy is needed only when the underlying USM array
    # is small (fewer than 16 elements) AND starts at a non-zero
    # element offset within its allocation.
    a_usm = dpnp.get_usm_ndarray(a)
    return a_usm._element_offset > 0 and a_usm.size < 16
912+
913+
x1 = _copy_array(
914+
x1,
915+
copy_flag=_need_to_copy(x1),
916+
dtype=compute_dtype,
917+
order=res_order,
918+
)
919+
x2 = _copy_array(
920+
x2,
921+
copy_flag=_need_to_copy(x2),
922+
dtype=compute_dtype,
923+
order=res_order,
924+
)
925+
897926
result = _gemm_matmul(
898927
exec_q,
899928
x1,

0 commit comments

Comments
 (0)