Skip to content

Commit 260b171

Browse files
Update TestEigenvalue
1 parent 6eab947 commit 260b171

File tree

1 file changed

+41
-2
lines changed

1 file changed

+41
-2
lines changed

tests/test_linalg.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -385,7 +385,46 @@ def test_eig_arange(type, size):
385385
assert_allclose(dpnp_vec, np_vec, rtol=1e-05, atol=1e-05)
386386

387387

388-
class TestEigenvalueSymm:
388+
class TestEigenvalue:
389+
@pytest.mark.parametrize(
390+
"func",
391+
[
392+
"eigh",
393+
"eigvalsh",
394+
],
395+
)
396+
@pytest.mark.parametrize(
397+
"shape",
398+
[(2, 2), (2, 3, 3), (2, 2, 3, 3)],
399+
ids=["(2,2)", "(2,3,3)", "(2,2,3,3)"],
400+
)
401+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
402+
@pytest.mark.parametrize(
403+
"order",
404+
[
405+
"C",
406+
"F",
407+
],
408+
)
409+
def test_eigenvalues(self, func, shape, dtype, order):
410+
a = generate_random_numpy_array(
411+
shape, dtype, hermitian=True, seed_value=81
412+
)
413+
a_order = numpy.array(a, order=order)
414+
a_dp = inp.array(a, order=order)
415+
416+
if func == "eigh":
417+
w, v = numpy.linalg.eigh(a_order)
418+
w_dp, v_dp = inp.linalg.eigh(a_dp)
419+
420+
assert_dtype_allclose(v_dp, v)
421+
422+
else: # eighvalsh
423+
w = numpy.linalg.eigvalsh(a_order)
424+
w_dp = inp.linalg.eigvalsh(a_dp)
425+
426+
assert_dtype_allclose(w_dp, w)
427+
389428
@pytest.mark.parametrize(
390429
"func",
391430
[
@@ -410,7 +449,7 @@ def test_eigenvalue_errors(self, func):
410449
assert_raises(inp.linalg.LinAlgError, dpnp_func, a_dp)
411450

412451
# invalid UPLO
413-
assert_raises(ValueError, dpnp_func, a_dp, "N")
452+
assert_raises(ValueError, dpnp_func, a_dp, UPLO="N")
414453

415454

416455
@pytest.mark.parametrize("type", get_all_dtypes(no_bool=True, no_complex=True))

0 commit comments

Comments
 (0)