Skip to content

Commit 6b63eea

Browse files
Add test_svd_hermitian
1 parent 0be3132 commit 6b63eea

File tree

1 file changed

+106
-56
lines changed

1 file changed

+106
-56
lines changed

tests/test_linalg.py

Lines changed: 106 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from .helper import (
99
assert_dtype_allclose,
1010
get_all_dtypes,
11+
get_complex_dtypes,
1112
has_support_aspect64,
1213
is_cpu_device,
1314
)
@@ -563,6 +564,76 @@ def test_solve_errors(self):
563564

564565

565566
class TestSvd:
567+
def set_tol(self, dtype):
568+
tol = 1e-06
569+
if dtype in (inp.float32, inp.complex64):
570+
tol = 1e-05
571+
elif not has_support_aspect64() and dtype in (
572+
inp.int32,
573+
inp.int64,
574+
None,
575+
):
576+
tol = 1e-05
577+
return tol
578+
579+
def check_types_shapes(
580+
self, dp_u, dp_s, dp_vt, np_u, np_s, np_vt, compute_vt=True
581+
):
582+
if has_support_aspect64():
583+
if compute_vt:
584+
assert dp_u.dtype == np_u.dtype
585+
assert dp_vt.dtype == np_vt.dtype
586+
assert dp_s.dtype == np_s.dtype
587+
else:
588+
if compute_vt:
589+
assert dp_u.dtype.kind == np_u.dtype.kind
590+
assert dp_vt.dtype.kind == np_vt.dtype.kind
591+
assert dp_s.dtype.kind == np_s.dtype.kind
592+
593+
if compute_vt:
594+
assert dp_u.shape == np_u.shape
595+
assert dp_vt.shape == np_vt.shape
596+
assert dp_s.shape == np_s.shape
597+
598+
def check_decomposition(
599+
self, dp_a, dp_u, dp_s, dp_vt, np_u, np_s, np_vt, compute_vt, tol
600+
):
601+
if compute_vt:
602+
dpnp_diag_s = inp.zeros_like(dp_a, dtype=dp_s.dtype)
603+
for i in range(min(dp_a.shape[-2], dp_a.shape[-1])):
604+
dpnp_diag_s[..., i, i] = dp_s[..., i]
605+
# TODO: remove it when dpnp.dot is updated
606+
# dpnp.dot does not support complex type
607+
if inp.issubdtype(dp_a.dtype, inp.complexfloating):
608+
reconstructed = numpy.dot(
609+
inp.asnumpy(dp_u),
610+
numpy.dot(inp.asnumpy(dpnp_diag_s), inp.asnumpy(dp_vt)),
611+
)
612+
else:
613+
reconstructed = inp.dot(dp_u, inp.dot(dpnp_diag_s, dp_vt))
614+
assert_allclose(dp_a, reconstructed, rtol=tol, atol=tol)
615+
616+
assert_allclose(dp_s, np_s, rtol=tol, atol=1e-03)
617+
618+
if compute_vt:
619+
for i in range(min(dp_a.shape[-2], dp_a.shape[-1])):
620+
if np_u[..., 0, i] * dp_u[..., 0, i] < 0:
621+
np_u[..., :, i] = -np_u[..., :, i]
622+
np_vt[..., i, :] = -np_vt[..., i, :]
623+
for i in range(numpy.count_nonzero(np_s > tol)):
624+
assert_allclose(
625+
inp.asnumpy(dp_u[..., :, i]),
626+
np_u[..., :, i],
627+
rtol=tol,
628+
atol=tol,
629+
)
630+
assert_allclose(
631+
inp.asnumpy(dp_vt[..., i, :]),
632+
np_vt[..., i, :],
633+
rtol=tol,
634+
atol=tol,
635+
)
636+
566637
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
567638
@pytest.mark.parametrize(
568639
"shape",
@@ -571,71 +642,50 @@ class TestSvd:
571642
)
572643
def test_svd(self, dtype, shape):
573644
a = numpy.arange(shape[0] * shape[1], dtype=dtype).reshape(shape)
574-
ia = inp.array(a)
645+
dp_a = inp.array(a)
575646

576647
np_u, np_s, np_vt = numpy.linalg.svd(a)
577-
dpnp_u, dpnp_s, dpnp_vt = inp.linalg.svd(ia)
578-
579-
support_aspect64 = has_support_aspect64()
648+
dp_u, dp_s, dp_vt = inp.linalg.svd(dp_a)
580649

581-
if support_aspect64:
582-
assert dpnp_u.dtype == np_u.dtype
583-
assert dpnp_s.dtype == np_s.dtype
584-
assert dpnp_vt.dtype == np_vt.dtype
585-
586-
assert dpnp_u.shape == np_u.shape
587-
assert dpnp_s.shape == np_s.shape
588-
assert dpnp_vt.shape == np_vt.shape
650+
self.check_types_shapes(dp_u, dp_s, dp_vt, np_u, np_s, np_vt)
651+
tol = self.set_tol(dtype)
652+
self.check_decomposition(
653+
dp_a, dp_u, dp_s, dp_vt, np_u, np_s, np_vt, True, tol
654+
)
589655

590-
tol = 1e-06
591-
if dtype in (inp.float32, inp.complex64):
592-
tol = 1e-05
593-
elif not support_aspect64 and dtype in (inp.int32, inp.int64, None):
594-
tol = 1e-05
656+
@pytest.mark.parametrize("dtype", get_complex_dtypes())
657+
@pytest.mark.parametrize("compute_vt", [True, False], ids=["True", "False"])
658+
@pytest.mark.parametrize(
659+
"shape",
660+
[(2, 2), (16, 16)],
661+
ids=["(2,2)", "(16, 16)"],
662+
)
663+
def test_svd_hermitian(self, dtype, compute_vt, shape):
664+
a = numpy.random.randn(*shape) + 1j * numpy.random.randn(*shape)
665+
a = numpy.conj(a.T) @ a
595666

596-
# check decomposition
597-
dpnp_diag_s = inp.zeros(shape, dtype=dpnp_s.dtype)
598-
for i in range(dpnp_s.size):
599-
dpnp_diag_s[i, i] = dpnp_s[i]
667+
a = a.astype(dtype)
668+
dp_a = inp.array(a)
600669

601-
# check decomposition
602-
# TODO: remove it when dpnp.dot is updated
603-
# dpnp.dot does not support complex type
604-
if inp.issubdtype(dtype, inp.complexfloating):
605-
assert_allclose(
606-
inp.asnumpy(ia),
607-
numpy.dot(
608-
inp.asnumpy(dpnp_u),
609-
numpy.dot(inp.asnumpy(dpnp_diag_s), inp.asnumpy(dpnp_vt)),
610-
),
611-
rtol=tol,
612-
atol=tol,
670+
if compute_vt:
671+
np_u, np_s, np_vt = numpy.linalg.svd(
672+
a, compute_uv=compute_vt, hermitian=True
613673
)
614-
else:
615-
assert_allclose(
616-
ia,
617-
inp.dot(dpnp_u, inp.dot(dpnp_diag_s, dpnp_vt)),
618-
rtol=tol,
619-
atol=tol,
674+
dp_u, dp_s, dp_vt = inp.linalg.svd(
675+
dp_a, compute_uv=compute_vt, hermitian=True
620676
)
677+
else:
678+
np_s = numpy.linalg.svd(a, compute_uv=compute_vt, hermitian=True)
679+
dp_s = inp.linalg.svd(dp_a, compute_uv=compute_vt, hermitian=True)
680+
np_u = np_vt = dp_u = dp_vt = None
621681

622-
# compare singular values
623-
assert_allclose(dpnp_s, np_s, rtol=tol, atol=1e-03)
624-
625-
# change sign of vectors
626-
for i in range(min(shape[0], shape[1])):
627-
if np_u[0, i] * dpnp_u[0, i] < 0:
628-
np_u[:, i] = -np_u[:, i]
629-
np_vt[i, :] = -np_vt[i, :]
630-
631-
# compare vectors for non-zero values
632-
for i in range(numpy.count_nonzero(np_s > tol)):
633-
assert_allclose(
634-
inp.asnumpy(dpnp_u)[:, i], np_u[:, i], rtol=tol, atol=tol
635-
)
636-
assert_allclose(
637-
inp.asnumpy(dpnp_vt)[i, :], np_vt[i, :], rtol=tol, atol=tol
638-
)
682+
self.check_types_shapes(
683+
dp_u, dp_s, dp_vt, np_u, np_s, np_vt, compute_vt
684+
)
685+
tol = self.set_tol(dtype)
686+
self.check_decomposition(
687+
dp_a, dp_u, dp_s, dp_vt, np_u, np_s, np_vt, compute_vt, tol
688+
)
639689

640690
def test_svd_errors(self):
641691
a_dp = inp.array([[1, 2], [3, 4]], dtype="float32")

0 commit comments

Comments
 (0)