@@ -662,35 +662,36 @@ DPCTLSyclEventRef dpnp_qr_c(DPCTLSyclQueueRef q_ref,
662
662
}
663
663
}
664
664
665
- DPNPC_ptr_adapter<_ComputeDT> result1_ptr (q_ref, result1, size_m * size_m, true , true );
666
- DPNPC_ptr_adapter<_ComputeDT> result2_ptr (q_ref, result2, size_m * size_n, true , true );
667
- DPNPC_ptr_adapter<_ComputeDT> result3_ptr (q_ref, result3, std::min (size_m, size_n), true , true );
665
+ const size_t min_size_m_n = std::min<size_t >(size_m, size_n);
666
+ DPNPC_ptr_adapter<_ComputeDT> result1_ptr (q_ref, result1, size_m * min_size_m_n, true , true );
667
+ DPNPC_ptr_adapter<_ComputeDT> result2_ptr (q_ref, result2, min_size_m_n * size_n, true , true );
668
+ DPNPC_ptr_adapter<_ComputeDT> result3_ptr (q_ref, result3, min_size_m_n, true , true );
668
669
_ComputeDT* res_q = result1_ptr.get_ptr ();
669
670
_ComputeDT* res_r = result2_ptr.get_ptr ();
670
671
_ComputeDT* tau = result3_ptr.get_ptr ();
671
672
672
673
const std::int64_t lda = size_m;
673
674
674
- const std::int64_t geqrf_scratchpad_size =
675
- mkl_lapack::geqrf_scratchpad_size<_ComputeDT>(q, size_m, size_n, lda);
675
+ const std::int64_t geqrf_scratchpad_size = mkl_lapack::geqrf_scratchpad_size<_ComputeDT>(q, size_m, size_n, lda);
676
676
677
677
_ComputeDT* geqrf_scratchpad =
678
678
reinterpret_cast <_ComputeDT*>(sycl::malloc_shared (geqrf_scratchpad_size * sizeof (_ComputeDT), q));
679
679
680
680
std::vector<sycl::event> depends (1 );
681
681
set_barrier_event (q, depends);
682
682
683
- event =
684
- mkl_lapack::geqrf (q, size_m, size_n, in_a, lda, tau, geqrf_scratchpad, geqrf_scratchpad_size, depends);
685
-
683
+ event = mkl_lapack::geqrf (q, size_m, size_n, in_a, lda, tau, geqrf_scratchpad, geqrf_scratchpad_size, depends);
686
684
event.wait ();
687
685
688
- verbose_print (" oneapi::mkl::lapack::geqrf" , depends.front (), event);
686
+ if (!depends.empty ()) {
687
+ verbose_print (" oneapi::mkl::lapack::geqrf" , depends.front (), event);
688
+ }
689
689
690
690
sycl::free (geqrf_scratchpad, q);
691
691
692
692
// R
693
- for (size_t i = 0 ; i < size_m; ++i)
693
+ size_t mrefl = min_size_m_n;
694
+ for (size_t i = 0 ; i < mrefl; ++i)
694
695
{
695
696
for (size_t j = 0 ; j < size_n; ++j)
696
697
{
@@ -706,37 +707,30 @@ DPCTLSyclEventRef dpnp_qr_c(DPCTLSyclQueueRef q_ref,
706
707
}
707
708
708
709
// Q
709
- const size_t nrefl = std::min< size_t >(size_m, size_n) ;
710
+ const size_t nrefl = min_size_m_n ;
710
711
const std::int64_t orgqr_scratchpad_size =
711
- mkl_lapack::orgqr_scratchpad_size<_ComputeDT>(q, size_m, size_m , nrefl, lda);
712
+ mkl_lapack::orgqr_scratchpad_size<_ComputeDT>(q, size_m, nrefl , nrefl, lda);
712
713
713
714
_ComputeDT* orgqr_scratchpad =
714
715
reinterpret_cast <_ComputeDT*>(sycl::malloc_shared (orgqr_scratchpad_size * sizeof (_ComputeDT), q));
715
716
716
- depends.clear ();
717
717
set_barrier_event (q, depends);
718
718
719
- event = mkl_lapack::orgqr (
720
- q, size_m, size_m, nrefl, in_a, lda, tau, orgqr_scratchpad, orgqr_scratchpad_size, depends);
721
-
719
+ event =
720
+ mkl_lapack::orgqr (q, size_m, nrefl, nrefl, in_a, lda, tau, orgqr_scratchpad, orgqr_scratchpad_size, depends);
722
721
event.wait ();
723
722
724
- verbose_print (" oneapi::mkl::lapack::orgqr" , depends.front (), event);
723
+ if (!depends.empty ()) {
724
+ verbose_print (" oneapi::mkl::lapack::orgqr" , depends.front (), event);
725
+ }
725
726
726
727
sycl::free (orgqr_scratchpad, q);
727
728
728
729
for (size_t i = 0 ; i < size_m; ++i)
729
730
{
730
- for (size_t j = 0 ; j < size_m ; ++j)
731
+ for (size_t j = 0 ; j < nrefl ; ++j)
731
732
{
732
- if (j < nrefl)
733
- {
734
- res_q[i * size_m + j] = in_a[j * size_m + i];
735
- }
736
- else
737
- {
738
- res_q[i * size_m + j] = _ComputeDT (0 );
739
- }
733
+ res_q[i * nrefl + j] = in_a[j * size_m + i];
740
734
}
741
735
}
742
736
0 commit comments