
Commit 8b6a94c

Merge fe4c862 into 48515c8

5 files changed: +65, -12 lines

dpnp/backend/extensions/blas/blas_py.cpp

Lines changed: 7 additions & 0 deletions
@@ -134,6 +134,13 @@ PYBIND11_MODULE(_blas_impl, m)
               py::arg("device"));
     }
 
+    {
+        m.def("_is_16_bytes_aligned", &blas_ns::_is_16_bytes_aligned,
+              "Return ``True`` if the pointer of the USM allocation has "
+              "16-byte alignment in memory",
+              py::arg("a"));
+    }
+
     {
         m.def("_gemm_batch", &blas_ns::gemm_batch,
               "Call `gemm_batch` from OneMKL BLAS library to compute "

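The block above exposes the alignment check to Python next to the other `_blas_impl` helpers. A minimal usage sketch, assuming the module is imported as `dpnp.backend.extensions.blas._blas_impl` (the `bi` alias used in the Python changes further below) and that the binding is passed a `usm_ndarray` obtained via `dpnp.get_usm_ndarray`:

    import dpnp
    import dpnp.backend.extensions.blas._blas_impl as bi

    # A float32 row view of a (3, 3) array starts 12 bytes past the base
    # pointer, so it is typically not 16-byte aligned even when the base is.
    a = dpnp.ones((3, 3), dtype=dpnp.float32)
    view = a[1]

    print(bi._is_16_bytes_aligned(dpnp.get_usm_ndarray(a)))
    print(bi._is_16_bytes_aligned(dpnp.get_usm_ndarray(view)))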
dpnp/backend/extensions/blas/gemm.cpp

Lines changed: 7 additions & 0 deletions
@@ -26,6 +26,7 @@
 #include <pybind11/pybind11.h>
 
 // dpctl tensor headers
+#include "kernels/alignment.hpp"
 #include "utils/memory_overlap.hpp"
 #include "utils/output_validation.hpp"
 #include "utils/type_utils.hpp"
@@ -339,6 +340,12 @@ bool _is_lnl_bm_architecture(const sycl::device &dev)
     return false;
 }
 
+bool _is_16_bytes_aligned(const dpctl::tensor::usm_ndarray &a)
+{
+    return dpctl::tensor::kernels::alignment_utils::is_aligned<16>(
+        a.get_data());
+}
+
 template <typename fnT, typename Tab, typename Tc>
 struct GemmContigFactory
 {
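For context, dpctl's `is_aligned<16>` presumably amounts to checking that the data pointer is a multiple of 16 bytes (an assumption about `kernels/alignment.hpp`, but that is the standard meaning of such a check). The same arithmetic can be illustrated on host memory with NumPy; the helper above does the equivalent on the USM pointer of the `usm_ndarray`:

    import numpy as np

    def is_16_bytes_aligned(arr):
        # An address is 16-byte aligned when it is a multiple of 16.
        return arr.ctypes.data % 16 == 0

    a = np.zeros(8, dtype=np.float32)  # base allocation, usually well aligned
    view = a[1:]                       # offset view: pointer shifted by 4 bytes

    print(is_16_bytes_aligned(a))      # usually True
    print(is_16_bytes_aligned(view))   # False whenever the base pointer is aligned

This is why the offset views exercised by the new test below (e.g. `ia[1]`) are exactly the cases that trigger the workaround.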

dpnp/backend/extensions/blas/gemm.hpp

Lines changed: 1 addition & 0 deletions
@@ -40,6 +40,7 @@ extern std::tuple<sycl::event, sycl::event, bool>
              const std::vector<sycl::event> &depends);
 
 extern bool _is_lnl_bm_architecture(const sycl::device &dev);
+extern bool _is_16_bytes_aligned(const dpctl::tensor::usm_ndarray &a);
 
 extern std::tuple<sycl::event, sycl::event, bool>
     gemm_batch(sycl::queue &exec_q,

dpnp/dpnp_utils/dpnp_utils_linearalgebra.py

Lines changed: 24 additions & 12 deletions
@@ -897,27 +897,19 @@ def dpnp_matmul(
             # MKLD-17976: due to known issue in OneMKL on Lunar Lake and
             # Battlemage G21 Intel GPU architectures, it forces
             # to implement a temporary workaround with extra copying of
-            # an input array in case when it has a small size and
-            # non-zero offset
-            # The issue was detected by failing tests for eig/eigh
+            # an input array in case when it does not have 16-byte
+            # alignment in memory.
             # TODO: remove the workaround once OneMKL issue is resolved
             if bi._is_lnl_bm_architecture(exec_q.get_sycl_device()):
-
-                def _need_to_copy(a):
-                    a_usm = dpnp.get_usm_ndarray(a)
-                    if a_usm._element_offset > 0 and a_usm.size < 16:
-                        return True
-                    return False
-
                 x1 = _copy_array(
                     x1,
-                    copy_flag=_need_to_copy(x1),
+                    copy_flag=not bi._is_16_bytes_aligned(x1),
                     dtype=compute_dtype,
                     order=res_order,
                 )
                 x2 = _copy_array(
                     x2,
-                    copy_flag=_need_to_copy(x2),
+                    copy_flag=not bi._is_16_bytes_aligned(x2),
                     dtype=compute_dtype,
                     order=res_order,
                 )
@@ -929,6 +921,26 @@ def _need_to_copy(a):
                 result,
             )
         else:  # call_flag == "gemm_batch"
+            # MKLD-17976: due to known issue in OneMKL on Lunar Lake and
+            # Battlemage G21 Intel GPU architectures, it forces
+            # to implement a temporary workaround with extra copying of
+            # an input array in case when it does not have 16-byte
+            # alignment in memory.
+            # TODO: remove the workaround once OneMKL issue is resolved
+            if bi._is_lnl_bm_architecture(exec_q.get_sycl_device()):
+                x1 = _copy_array(
+                    x1,
+                    copy_flag=not bi._is_16_bytes_aligned(x1),
+                    dtype=compute_dtype,
+                    order=res_order,
+                )
+                x2 = _copy_array(
+                    x2,
+                    copy_flag=not bi._is_16_bytes_aligned(x2),
+                    dtype=compute_dtype,
+                    order=res_order,
+                )
+
             result = _gemm_batch_matmul(
                 exec_q,
                 x1,
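Both branches now apply the same pattern: copy an input only on the affected Lunar Lake / Battlemage G21 devices and only when its data pointer is not 16-byte aligned. A condensed sketch of that shared logic (`_align_for_gemm` is a hypothetical helper, not part of the commit; `bi`, `_copy_array`, `compute_dtype` and `res_order` are the names used in the diff above):

    def _align_for_gemm(x, exec_q, compute_dtype, res_order):
        # MKLD-17976 workaround: copy only on Lunar Lake / Battlemage G21
        # devices and only when the USM data pointer lacks 16-byte alignment.
        need_copy = bi._is_lnl_bm_architecture(
            exec_q.get_sycl_device()
        ) and not bi._is_16_bytes_aligned(x)
        return _copy_array(
            x, copy_flag=need_copy, dtype=compute_dtype, order=res_order
        )

Each copy block above would then reduce to `x1 = _align_for_gemm(x1, exec_q, compute_dtype, res_order)`, and likewise for `x2`.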

tests/test_mathematical.py

Lines changed: 26 additions & 0 deletions
@@ -3824,6 +3824,32 @@ def test_matmul_alias(self):
         result2 = dpnp.linalg.matmul(a, b)
         assert_array_equal(result1, result2)
 
+    @pytest.mark.parametrize(
+        "sh1, sh2",
+        [
+            ((2, 3, 3), (3, 3)),
+            ((3, 4, 4, 4), (4, 4, 4)),
+        ],
+        ids=["gemm", "gemm_batch"],
+    )
+    def test_matmul_with_offsets(self, sh1, sh2):
+        size1, size2 = numpy.prod(sh1, dtype=int), numpy.prod(sh2, dtype=int)
+        a = numpy.random.randint(-5, 5, size1).reshape(sh1)
+        b = numpy.random.randint(-5, 5, size2).reshape(sh2)
+        ia, ib = dpnp.array(a), dpnp.array(b)
+
+        result = ia[1] @ ib
+        expected = a[1] @ b
+        assert_array_equal(result, expected)
+
+        result = ib @ ia[1]
+        expected = b @ a[1]
+        assert_array_equal(result, expected)
+
+        result = ia[1] @ ia[1]
+        expected = a[1] @ a[1]
+        assert_array_equal(result, expected)
+
 
 class TestMatmulInvalidCases:
     @pytest.mark.parametrize(
