Skip to content

Commit f8b0ce8

Browse files
authored
Function "linalg.qr" raised an exception with mode='reduced' (#1148)
* DPNP function "linalg.qr" raised an exception when mode='reduced'.
1 parent 0234590 commit f8b0ce8

File tree

5 files changed

+32
-38
lines changed

5 files changed

+32
-38
lines changed

dpnp/backend/kernels/dpnp_krnl_linalg.cpp

Lines changed: 20 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -662,35 +662,36 @@ DPCTLSyclEventRef dpnp_qr_c(DPCTLSyclQueueRef q_ref,
662662
}
663663
}
664664

665-
DPNPC_ptr_adapter<_ComputeDT> result1_ptr(q_ref, result1, size_m * size_m, true, true);
666-
DPNPC_ptr_adapter<_ComputeDT> result2_ptr(q_ref, result2, size_m * size_n, true, true);
667-
DPNPC_ptr_adapter<_ComputeDT> result3_ptr(q_ref, result3, std::min(size_m, size_n), true, true);
665+
const size_t min_size_m_n = std::min<size_t>(size_m, size_n);
666+
DPNPC_ptr_adapter<_ComputeDT> result1_ptr(q_ref, result1, size_m * min_size_m_n, true, true);
667+
DPNPC_ptr_adapter<_ComputeDT> result2_ptr(q_ref, result2, min_size_m_n * size_n, true, true);
668+
DPNPC_ptr_adapter<_ComputeDT> result3_ptr(q_ref, result3, min_size_m_n, true, true);
668669
_ComputeDT* res_q = result1_ptr.get_ptr();
669670
_ComputeDT* res_r = result2_ptr.get_ptr();
670671
_ComputeDT* tau = result3_ptr.get_ptr();
671672

672673
const std::int64_t lda = size_m;
673674

674-
const std::int64_t geqrf_scratchpad_size =
675-
mkl_lapack::geqrf_scratchpad_size<_ComputeDT>(q, size_m, size_n, lda);
675+
const std::int64_t geqrf_scratchpad_size = mkl_lapack::geqrf_scratchpad_size<_ComputeDT>(q, size_m, size_n, lda);
676676

677677
_ComputeDT* geqrf_scratchpad =
678678
reinterpret_cast<_ComputeDT*>(sycl::malloc_shared(geqrf_scratchpad_size * sizeof(_ComputeDT), q));
679679

680680
std::vector<sycl::event> depends(1);
681681
set_barrier_event(q, depends);
682682

683-
event =
684-
mkl_lapack::geqrf(q, size_m, size_n, in_a, lda, tau, geqrf_scratchpad, geqrf_scratchpad_size, depends);
685-
683+
event = mkl_lapack::geqrf(q, size_m, size_n, in_a, lda, tau, geqrf_scratchpad, geqrf_scratchpad_size, depends);
686684
event.wait();
687685

688-
verbose_print("oneapi::mkl::lapack::geqrf", depends.front(), event);
686+
if (!depends.empty()) {
687+
verbose_print("oneapi::mkl::lapack::geqrf", depends.front(), event);
688+
}
689689

690690
sycl::free(geqrf_scratchpad, q);
691691

692692
// R
693-
for (size_t i = 0; i < size_m; ++i)
693+
size_t mrefl = min_size_m_n;
694+
for (size_t i = 0; i < mrefl; ++i)
694695
{
695696
for (size_t j = 0; j < size_n; ++j)
696697
{
@@ -706,37 +707,30 @@ DPCTLSyclEventRef dpnp_qr_c(DPCTLSyclQueueRef q_ref,
706707
}
707708

708709
// Q
709-
const size_t nrefl = std::min<size_t>(size_m, size_n);
710+
const size_t nrefl = min_size_m_n;
710711
const std::int64_t orgqr_scratchpad_size =
711-
mkl_lapack::orgqr_scratchpad_size<_ComputeDT>(q, size_m, size_m, nrefl, lda);
712+
mkl_lapack::orgqr_scratchpad_size<_ComputeDT>(q, size_m, nrefl, nrefl, lda);
712713

713714
_ComputeDT* orgqr_scratchpad =
714715
reinterpret_cast<_ComputeDT*>(sycl::malloc_shared(orgqr_scratchpad_size * sizeof(_ComputeDT), q));
715716

716-
depends.clear();
717717
set_barrier_event(q, depends);
718718

719-
event = mkl_lapack::orgqr(
720-
q, size_m, size_m, nrefl, in_a, lda, tau, orgqr_scratchpad, orgqr_scratchpad_size, depends);
721-
719+
event =
720+
mkl_lapack::orgqr(q, size_m, nrefl, nrefl, in_a, lda, tau, orgqr_scratchpad, orgqr_scratchpad_size, depends);
722721
event.wait();
723722

724-
verbose_print("oneapi::mkl::lapack::orgqr", depends.front(), event);
723+
if (!depends.empty()) {
724+
verbose_print("oneapi::mkl::lapack::orgqr", depends.front(), event);
725+
}
725726

726727
sycl::free(orgqr_scratchpad, q);
727728

728729
for (size_t i = 0; i < size_m; ++i)
729730
{
730-
for (size_t j = 0; j < size_m; ++j)
731+
for (size_t j = 0; j < nrefl; ++j)
731732
{
732-
if (j < nrefl)
733-
{
734-
res_q[i * size_m + j] = in_a[j * size_m + i];
735-
}
736-
else
737-
{
738-
res_q[i * size_m + j] = _ComputeDT(0);
739-
}
733+
res_q[i * nrefl + j] = in_a[j * size_m + i];
740734
}
741735
}
742736

dpnp/linalg/dpnp_algo_linalg.pyx

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -308,13 +308,14 @@ cpdef object dpnp_norm(object input, ord=None, axis=None):
308308
cpdef tuple dpnp_qr(utils.dpnp_descriptor x1, str mode):
309309
cdef size_t size_m = x1.shape[0]
310310
cdef size_t size_n = x1.shape[1]
311-
cdef size_t size_tau = min(size_m, size_n)
311+
cdef size_t min_m_n = min(size_m, size_n)
312+
cdef size_t size_tau = min_m_n
312313

313314
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(x1.dtype)
314315
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_QR, param1_type, param1_type)
315316

316-
cdef utils.dpnp_descriptor res_q = utils.create_output_descriptor((size_m, size_m), kernel_data.return_type, None)
317-
cdef utils.dpnp_descriptor res_r = utils.create_output_descriptor((size_m, size_n), kernel_data.return_type, None)
317+
cdef utils.dpnp_descriptor res_q = utils.create_output_descriptor((size_m, min_m_n), kernel_data.return_type, None)
318+
cdef utils.dpnp_descriptor res_r = utils.create_output_descriptor((min_m_n, size_n), kernel_data.return_type, None)
318319
cdef utils.dpnp_descriptor tau = utils.create_output_descriptor((size_tau, ), kernel_data.return_type, None)
319320

320321
cdef custom_linalg_1in_3out_shape_t func = < custom_linalg_1in_3out_shape_t > kernel_data.ptr

dpnp/linalg/dpnp_iface_linalg.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,7 @@ def qr(x1, mode='reduced'):
391391
Limitations
392392
-----------
393393
Input array is supported as :obj:`dpnp.ndarray`.
394-
Parameter mode='complete' is supported.
394+
Parameter mode='reduced' is supported.
395395
396396
"""
397397

@@ -400,7 +400,7 @@ def qr(x1, mode='reduced'):
400400
if mode != 'reduced':
401401
pass
402402
else:
403-
result_tup = dpnp_qr(x1, mode)
403+
result_tup = dpnp_qr(x1_desc, mode)
404404

405405
return result_tup
406406

tests/skipped_tests_gpu.tbl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -263,10 +263,6 @@ tests/test_linalg.py::test_eig_arange[16-float64]
263263
tests/test_linalg.py::test_eig_arange[16-float32]
264264
tests/test_linalg.py::test_eig_arange[16-int64]
265265
tests/test_linalg.py::test_eig_arange[16-int32]
266-
tests/test_linalg.py::test_qr[(16,16)-float64]
267-
tests/test_linalg.py::test_qr[(16,16)-float32]
268-
tests/test_linalg.py::test_qr[(16,16)-int64]
269-
tests/test_linalg.py::test_qr[(16,16)-int32]
270266
tests/test_random.py::test_randn_normal_distribution
271267
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_multidim_outer
272268
tests/third_party/cupy/random_tests/test_sample.py::TestRandintDtype::test_dtype

tests/test_linalg.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -219,12 +219,15 @@ def test_norm3(array, ord, axis):
219219
@pytest.mark.parametrize("shape",
220220
[(2, 2), (3, 4), (5, 3), (16, 16)],
221221
ids=['(2,2)', '(3,4)', '(5,3)', '(16,16)'])
222-
def test_qr(type, shape):
222+
@pytest.mark.parametrize("mode",
223+
['complete', 'reduced'],
224+
ids=['complete', 'reduced'])
225+
def test_qr(type, shape, mode):
223226
a = numpy.arange(shape[0] * shape[1], dtype=type).reshape(shape)
224227
ia = inp.array(a)
225228

226-
np_q, np_r = numpy.linalg.qr(a, "complete")
227-
dpnp_q, dpnp_r = inp.linalg.qr(ia, "complete")
229+
np_q, np_r = numpy.linalg.qr(a, mode)
230+
dpnp_q, dpnp_r = inp.linalg.qr(ia, mode)
228231

229232
assert (dpnp_q.dtype == np_q.dtype)
230233
assert (dpnp_r.dtype == np_r.dtype)

0 commit comments

Comments
 (0)