Skip to content

Commit 87825f5

Browse files
Address remarks
1 parent 96f243c commit 87825f5

File tree

5 files changed

+40
-60
lines changed

5 files changed

+40
-60
lines changed

dpnp/backend/extensions/lapack/gesvd.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ static gesvd_impl_fn_ptr_t gesvd_dispatch_table[dpctl_td_ns::num_types]
6666

6767
// Converts a given character code (ord) to the corresponding
6868
// oneapi::mkl::jobsvd enumeration value
69-
oneapi::mkl::jobsvd process_job(std::int8_t job_val)
69+
static oneapi::mkl::jobsvd process_job(std::int8_t job_val)
7070
{
7171
switch (job_val) {
7272
case 'A':

dpnp/dpnp_algo/dpnp_algo.pxd

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
187187
DPNP_FN_SORT_EXT
188188
DPNP_FN_SUM
189189
DPNP_FN_SUM_EXT
190-
DPNP_FN_SVD
191190
DPNP_FN_TRACE
192191
DPNP_FN_TRACE_EXT
193192
DPNP_FN_TRANSPOSE

dpnp/linalg/dpnp_iface_linalg.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -646,6 +646,8 @@ def svd(a, full_matrices=True, compute_uv=True, hermitian=False):
646646
check_stacked_2d(a)
647647

648648
if hermitian is True:
649-
raise ValueError("The hermitian argument is only supported as False")
649+
raise NotImplementedError(
650+
"hermitian keyword argument is only supported with its default value."
651+
)
650652

651653
return dpnp_svd(a, full_matrices, compute_uv)

dpnp/linalg/dpnp_utils_linalg.py

Lines changed: 35 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -595,36 +595,29 @@ def dpnp_svd(a, full_matrices=True, compute_uv=True):
595595
)
596596
if compute_uv:
597597
if full_matrices:
598-
u = dpnp.eye(
599-
n,
600-
dtype=uv_type,
601-
usm_type=a_usm_type,
602-
sycl_queue=a_sycl_queue,
603-
)
604-
vt = dpnp.eye(
605-
m,
606-
dtype=uv_type,
607-
usm_type=a_usm_type,
608-
sycl_queue=a_sycl_queue,
609-
)
598+
u_shape = (n,)
599+
vt_shape = (m,)
610600
else:
611-
u = dpnp.empty(
612-
(n, 0),
613-
dtype=uv_type,
614-
usm_type=a_usm_type,
615-
sycl_queue=a_sycl_queue,
616-
)
617-
vt = dpnp.empty(
618-
(0, m),
619-
dtype=uv_type,
620-
usm_type=a_usm_type,
621-
sycl_queue=a_sycl_queue,
622-
)
601+
u_shape = (n, 0)
602+
vt_shape = (0, m)
603+
604+
u = dpnp.eye(
605+
*u_shape,
606+
dtype=uv_type,
607+
usm_type=a_usm_type,
608+
sycl_queue=a_sycl_queue,
609+
)
610+
vt = dpnp.eye(
611+
*vt_shape,
612+
dtype=uv_type,
613+
usm_type=a_usm_type,
614+
sycl_queue=a_sycl_queue,
615+
)
623616
return u, s, vt
624617
else:
625618
return s
626619

627-
# `a`` must be copied because gesvd destroys the input matrix
620+
# `a` must be copied because gesvd destroys the input matrix
628621
# `a` must be traspotted if m < n
629622
if m >= n:
630623
x = a
@@ -647,46 +640,32 @@ def dpnp_svd(a, full_matrices=True, compute_uv=True):
647640
k = n # = min(m, n) where m >= n is ensured above
648641
if compute_uv:
649642
if full_matrices:
650-
u_h = dpnp.empty(
651-
(m, m),
652-
dtype=uv_type,
653-
usm_type=a_usm_type,
654-
sycl_queue=a_sycl_queue,
655-
)
656-
vt_h = dpnp.empty(
657-
(n, n),
658-
dtype=uv_type,
659-
usm_type=a_usm_type,
660-
sycl_queue=a_sycl_queue,
661-
)
643+
u_shape = (m, m)
644+
vt_shape = (n, n)
662645
jobu = ord("A")
663646
jobvt = ord("A")
664647
else:
665-
u_h = dpnp.empty_like(x, dtype=uv_type)
666-
vt_h = dpnp.empty(
667-
(k, n),
668-
dtype=uv_type,
669-
usm_type=a_usm_type,
670-
sycl_queue=a_sycl_queue,
671-
)
648+
u_shape = x.shape
649+
vt_shape = (k, n)
672650
jobu = ord("S")
673651
jobvt = ord("S")
674652
else:
675-
u_h = dpnp.empty(
676-
[],
677-
dtype=uv_type,
678-
usm_type=a_usm_type,
679-
sycl_queue=a_sycl_queue,
680-
)
681-
vt_h = dpnp.empty(
682-
[],
683-
dtype=uv_type,
684-
usm_type=a_usm_type,
685-
sycl_queue=a_sycl_queue,
686-
)
653+
u_shape = vt_shape = ()
687654
jobu = ord("N")
688655
jobvt = ord("N")
689656

657+
u_h = dpnp.empty(
658+
u_shape,
659+
dtype=uv_type,
660+
usm_type=a_usm_type,
661+
sycl_queue=a_sycl_queue,
662+
)
663+
vt_h = dpnp.empty(
664+
vt_shape,
665+
dtype=uv_type,
666+
usm_type=a_usm_type,
667+
sycl_queue=a_sycl_queue,
668+
)
690669
s_h = dpnp.empty(
691670
k, dtype=s_type, usm_type=a_usm_type, sycl_queue=a_sycl_queue
692671
)

tests/test_linalg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -644,7 +644,7 @@ def test_svd_errors(self):
644644
assert_raises(TypeError, inp.linalg.svd, a_np)
645645

646646
# unsupported hermitian argument
647-
assert_raises(ValueError, inp.linalg.svd, a_dp, hermitian=True)
647+
assert_raises(NotImplementedError, inp.linalg.svd, a_dp, hermitian=True)
648648

649649
# a.ndim < 2
650650
a_dp_ndim_1 = a_dp.flatten()

0 commit comments

Comments
 (0)