Skip to content

Commit a7bb8e0

Browse files
Update test_svd_hermitian
1 parent e7b899a commit a7bb8e0

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

tests/test_linalg.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1122,18 +1122,19 @@ def test_svd(self, dtype, shape):
11221122
dp_a, dp_u, dp_s, dp_vt, np_u, np_s, np_vt, True
11231123
)
11241124

1125-
@pytest.mark.parametrize("dtype", get_complex_dtypes())
1125+
@pytest.mark.parametrize("dtype", get_float_complex_dtypes())
11261126
@pytest.mark.parametrize("compute_vt", [True, False], ids=["True", "False"])
11271127
@pytest.mark.parametrize(
11281128
"shape",
11291129
[(2, 2), (16, 16)],
1130-
ids=["(2,2)", "(16, 16)"],
1130+
ids=["(2, 2)", "(16, 16)"],
11311131
)
11321132
def test_svd_hermitian(self, dtype, compute_vt, shape):
1133-
a = numpy.random.randn(*shape) + 1j * numpy.random.randn(*shape)
1134-
a = numpy.conj(a.T) @ a
1133+
a = numpy.random.randn(*shape).astype(dtype)
1134+
if numpy.issubdtype(dtype, numpy.complexfloating):
1135+
a += 1j * numpy.random.randn(*shape)
1136+
a = (a + a.conj().T) / 2
11351137

1136-
a = a.astype(dtype)
11371138
dp_a = inp.array(a)
11381139

11391140
if compute_vt:

0 commit comments

Comments
 (0)