|
18 | 18 |
|
19 | 19 | from .array_helpers import assert_exactly_equal, ndindex, asarray
|
20 | 20 | from .hypothesis_helpers import (xps, dtypes, shapes, kwargs, matrix_shapes,
|
21 |
| - square_matrix_shapes, |
| 21 | + square_matrix_shapes, symmetric_matrices, |
22 | 22 | positive_definite_matrices, MAX_ARRAY_SIZE)
|
23 | 23 |
|
24 | 24 | from . import _array_module
|
@@ -122,11 +122,36 @@ def true_diag(x_stack):
|
122 | 122 | _test_stacks(_array_module.linalg.diagonal, x, kw, res, dims=1, true_val=true_diag)
|
123 | 123 |
|
124 | 124 | @given(
|
125 |
| - x=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes), |
| 125 | + x=symmetric_matrices(finite=True), |
126 | 126 | )
|
127 | 127 | def test_eigh(x):
|
128 |
| - # res = _array_module.linalg.eigh(x) |
129 |
| - pass |
| 128 | + res = _array_module.linalg.eigh(x) |
| 129 | + |
| 130 | + # TODO: Factor out namedtuple checks |
| 131 | + |
| 132 | + # isinstance(namedtuple) doesn't work |
| 133 | + assert isinstance(res, tuple), "eigh() did not return a tuple" |
| 134 | + assert len(res) == 2, "eigh() result tuple not the correct length" |
| 135 | + assert hasattr(res, 'eigenvalues'), "eigh() result namedtuple doesn't have the 'eigenvalues' field" |
| 136 | + assert hasattr(res, 'eigenvectors'), "eigh() result namedtuple doesn't have the 'eigenvectors' field" |
| 137 | + |
| 138 | + eigenvalues = res.eigenvalues |
| 139 | + eigenvectors = res.eigenvectors |
| 140 | + |
| 141 | + assert_exactly_equal(res[0], eigenvalues) |
| 142 | + assert_exactly_equal(res[1], eigenvectors) |
| 143 | + |
| 144 | + assert eigenvalues.dtype == x.dtype, "eigh().eigenvalues did not return the correct dtype" |
| 145 | + assert eigenvalues.shape == x.shape[:-1], "eigh().eigenvalues did not return the correct shape" |
| 146 | + |
| 147 | + assert eigenvectors.dtype == x.dtype, "eigh().eigenvectors did not return the correct dtype" |
| 148 | + assert eigenvectors.shape == x.shape, "eigh().eigenvectors did not return the correct shape" |
| 149 | + |
| 150 | + _test_stacks(lambda x: _array_module.linalg.eigh(x).eigenvalues, x, {}, eigenvalues, dims=1) |
| 151 | + _test_stacks(lambda x: _array_module.linalg.eigh(x).eigenvectors, x, {}, eigenvectors, dims=2) |
| 152 | + |
| 153 | + # TODO: Test that res actually corresponds to the eigenvalues and |
| 154 | + # eigenvectors of x |
130 | 155 |
|
131 | 156 | @given(
|
132 | 157 | x=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes),
|
|
0 commit comments