
Commit 11d3a94

Update the w/a based on input from OneMKL team
1 parent 48515c8 commit 11d3a94

File tree

dpnp/backend/extensions/blas/blas_py.cpp
dpnp/backend/extensions/blas/gemm.cpp
dpnp/backend/extensions/blas/gemm.hpp
dpnp/dpnp_utils/dpnp_utils_linearalgebra.py

4 files changed: +39 -12 lines changed

dpnp/backend/extensions/blas/blas_py.cpp

Lines changed: 7 additions & 0 deletions

@@ -134,6 +134,13 @@ PYBIND11_MODULE(_blas_impl, m)
             py::arg("device"));
     }
 
+    {
+        m.def("_is_16_bytes_aligned", &blas_ns::_is_16_bytes_aligned,
+              "Return ``True`` if pointer on USM allocation has 16 bytes "
+              "alignment in memory",
+              py::arg("a"));
+    }
+
     {
         m.def("_gemm_batch", &blas_ns::gemm_batch,
               "Call `gemm_batch` from OneMKL BLAS library to compute "

dpnp/backend/extensions/blas/gemm.cpp

Lines changed: 7 additions & 0 deletions

@@ -26,6 +26,7 @@
 #include <pybind11/pybind11.h>
 
 // dpctl tensor headers
+#include "kernels/alignment.hpp"
 #include "utils/memory_overlap.hpp"
 #include "utils/output_validation.hpp"
 #include "utils/type_utils.hpp"
@@ -339,6 +340,12 @@ bool _is_lnl_bm_architecture(const sycl::device &dev)
     return false;
 }
 
+bool _is_16_bytes_aligned(const dpctl::tensor::usm_ndarray &a)
+{
+    return dpctl::tensor::kernels::alignment_utils::is_aligned<16>(
+        a.get_data());
+}
+
 template <typename fnT, typename Tab, typename Tc>
 struct GemmContigFactory
 {
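The `is_aligned<16>` helper comes from dpctl's `kernels/alignment.hpp` header. As a rough sketch of what such a predicate boils down to (an illustration, not dpctl's exact implementation): a pointer is N-byte aligned exactly when its integer address is a multiple of N, which for power-of-two N is a single bitmask test.

```cpp
#include <cstdint>
#include <cstdlib>

// Sketch of a pointer-alignment predicate: the address must be a
// multiple of N; for power-of-two N the low bits must all be zero.
template <std::size_t N>
bool is_aligned(const void *ptr)
{
    static_assert(N != 0 && (N & (N - 1)) == 0, "N must be a power of two");
    return (reinterpret_cast<std::uintptr_t>(ptr) & (N - 1)) == 0;
}

int main()
{
    alignas(16) double buf[4] = {};
    bool head_ok = is_aligned<16>(buf);     // true: buffer starts aligned
    bool tail_ok = is_aligned<16>(buf + 1); // false: +8 bytes breaks it
    return (head_ok && !tail_ok) ? EXIT_SUCCESS : EXIT_FAILURE;
}
```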

dpnp/backend/extensions/blas/gemm.hpp

Lines changed: 1 addition & 0 deletions

@@ -40,6 +40,7 @@ extern std::tuple<sycl::event, sycl::event, bool>
                   const std::vector<sycl::event> &depends);
 
 extern bool _is_lnl_bm_architecture(const sycl::device &dev);
+extern bool _is_16_bytes_aligned(const dpctl::tensor::usm_ndarray &a);
 
 extern std::tuple<sycl::event, sycl::event, bool>
     gemm_batch(sycl::queue &exec_q,

dpnp/dpnp_utils/dpnp_utils_linearalgebra.py

Lines changed: 24 additions & 12 deletions

@@ -897,27 +897,19 @@ def dpnp_matmul(
             # MKLD-17976: due to known issue in OneMKL on Lunar Lake and
             # Battlemage G21 Intel GPU architectures, it forces
             # to implement a temporary workaround with extra copying of
-            # an input array in case when it has a small size and
-            # non-zero offset
-            # The issue was detected by failing tests for eig/eigh
+            # an input array in case when it does not have 16 bytes
+            # alignment in the memory.
             # TODO: remove the workaround once OneMKL issue is resolved
             if bi._is_lnl_bm_architecture(exec_q.get_sycl_device()):
-
-                def _need_to_copy(a):
-                    a_usm = dpnp.get_usm_ndarray(a)
-                    if a_usm._element_offset > 0 and a_usm.size < 16:
-                        return True
-                    return False
-
                 x1 = _copy_array(
                     x1,
-                    copy_flag=_need_to_copy(x1),
+                    copy_flag=not bi._is_16_bytes_aligned(x1),
                     dtype=compute_dtype,
                     order=res_order,
                 )
                 x2 = _copy_array(
                     x2,
-                    copy_flag=_need_to_copy(x2),
+                    copy_flag=not bi._is_16_bytes_aligned(x2),
                     dtype=compute_dtype,
                     order=res_order,
                 )
@@ -929,6 +921,26 @@ def _need_to_copy(a):
                 result,
             )
         else:  # call_flag == "gemm_batch"
+            # MKLD-17976: due to known issue in OneMKL on Lunar Lake and
+            # Battlemage G21 Intel GPU architectures, it forces
+            # to implement a temporary workaround with extra copying of
+            # an input array in case when it does not have 16 bytes
+            # alignment in the memory.
+            # TODO: remove the workaround once OneMKL issue is resolved
+            if bi._is_lnl_bm_architecture(exec_q.get_sycl_device()):
+                x1 = _copy_array(
+                    x1,
+                    copy_flag=not bi._is_16_bytes_aligned(x1),
+                    dtype=compute_dtype,
+                    order=res_order,
+                )
+                x2 = _copy_array(
+                    x2,
+                    copy_flag=not bi._is_16_bytes_aligned(x2),
+                    dtype=compute_dtype,
+                    order=res_order,
+                )
+
             result = _gemm_batch_matmul(
                 exec_q,
                 x1,
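The Python-side workaround copies an input only when its data pointer is misaligned. The underlying mechanics: a view that starts at a non-zero element offset inside an otherwise aligned allocation can land off a 16-byte boundary, while a fresh copy starts at the beginning of a new allocation. A small illustration, under the assumption (hedged in the comments) that fresh allocations from `operator new` are at least 16-byte aligned, as on typical 64-bit targets:

```cpp
#include <cstdint>
#include <vector>

// Returns true when ptr sits on a 16-byte boundary.
static bool aligned16(const void *ptr)
{
    return reinterpret_cast<std::uintptr_t>(ptr) % 16 == 0;
}

int main()
{
    // Fresh allocation: 16-byte aligned on typical 64-bit implementations,
    // where operator new returns memory aligned to alignof(std::max_align_t).
    std::vector<double> base(8, 0.0);
    const double *view = base.data() + 1; // 1-element (8-byte) offset view

    // The offset view is misaligned, so the workaround copies it into a
    // new allocation, whose data pointer is aligned again.
    std::vector<double> copy(view, view + 4);
    return (!aligned16(view) && aligned16(copy.data())) ? 0 : 1;
}
```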
