Skip to content

Commit 3711f3a

Browse files
committed
Implement test_eigh()
1 parent 9c606c8 commit 3711f3a

File tree

1 file changed

+29
-4
lines changed

1 file changed

+29
-4
lines changed

array_api_tests/test_linalg.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from .array_helpers import assert_exactly_equal, ndindex, asarray
2020
from .hypothesis_helpers import (xps, dtypes, shapes, kwargs, matrix_shapes,
21-
square_matrix_shapes,
21+
square_matrix_shapes, symmetric_matrices,
2222
positive_definite_matrices, MAX_ARRAY_SIZE)
2323

2424
from . import _array_module
@@ -122,11 +122,36 @@ def true_diag(x_stack):
122122
_test_stacks(_array_module.linalg.diagonal, x, kw, res, dims=1, true_val=true_diag)
123123

124124
@given(
125-
x=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes),
125+
x=symmetric_matrices(finite=True),
126126
)
127127
def test_eigh(x):
128-
# res = _array_module.linalg.eigh(x)
129-
pass
128+
res = _array_module.linalg.eigh(x)
129+
130+
# TODO: Factor out namedtuple checks
131+
132+
# isinstance(namedtuple) doesn't work
133+
assert isinstance(res, tuple), "eigh() did not return a tuple"
134+
assert len(res) == 2, "eigh() result tuple not the correct length"
135+
assert hasattr(res, 'eigenvalues'), "eigh() result namedtuple doesn't have the 'eigenvalues' field"
136+
assert hasattr(res, 'eigenvectors'), "eigh() result namedtuple doesn't have the 'eigenvectors' field"
137+
138+
eigenvalues = res.eigenvalues
139+
eigenvectors = res.eigenvectors
140+
141+
assert_exactly_equal(res[0], eigenvalues)
142+
assert_exactly_equal(res[1], eigenvectors)
143+
144+
assert eigenvalues.dtype == x.dtype, "eigh().eigenvalues did not return the correct dtype"
145+
assert eigenvalues.shape == x.shape[:-1], "eigh().eigenvalues did not return the correct shape"
146+
147+
assert eigenvectors.dtype == x.dtype, "eigh().eigenvectors did not return the correct dtype"
148+
assert eigenvectors.shape == x.shape, "eigh().eigenvectors did not return the correct shape"
149+
150+
_test_stacks(lambda x: _array_module.linalg.eigh(x).eigenvalues, x, {}, eigenvalues, dims=1)
151+
_test_stacks(lambda x: _array_module.linalg.eigh(x).eigenvectors, x, {}, eigenvectors, dims=2)
152+
153+
# TODO: Test that res actually corresponds to the eigenvalues and
154+
# eigenvectors of x
130155

131156
@given(
132157
x=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes),

0 commit comments

Comments
 (0)