@@ -47,6 +47,20 @@ def _test_stacks(f, x, kw, res=None, dims=2, true_val=None):
47
47
if true_val :
48
48
assert_exactly_equal (decomp_res_stack , true_val (x_stack ))
49
49
50
+ def _test_namedtuple (res , fields , func_name ):
51
+ """
52
+ Test that res is a namedtuple with the correct fields.
53
+ """
54
+ # isinstance(namedtuple) doesn't work, and it could be either
55
+ # collections.namedtuple or typing.NamedTuple. So we just check that it is
56
+ # a tuple subclass with the right fields in the right order.
57
+
58
+ assert isinstance (res , tuple ), f"{ func_name } () did not return a tuple"
59
+ assert len (res ) == len (fields ), f"{ func_name } () result tuple not the correct length (should have { len (fields )} elements)"
60
+ for i , field in enumerate (fields ):
61
+ assert hasattr (res , field ), f"{ func_name } () result namedtuple doesn't have the '{ field } ' field"
62
+ assert res [i ] is getattr (res , field ), f"{ func_name } () result namedtuple '{ field } ' field is not in position { i } "
63
+
50
64
@given (
51
65
x = positive_definite_matrices (),
52
66
kw = kwargs (upper = booleans ())
@@ -127,20 +141,11 @@ def true_diag(x_stack):
127
141
def test_eigh (x ):
128
142
res = _array_module .linalg .eigh (x )
129
143
130
- # TODO: Factor out namedtuple checks
131
-
132
- # isinstance(namedtuple) doesn't work
133
- assert isinstance (res , tuple ), "eigh() did not return a tuple"
134
- assert len (res ) == 2 , "eigh() result tuple not the correct length"
135
- assert hasattr (res , 'eigenvalues' ), "eigh() result namedtuple doesn't have the 'eigenvalues' field"
136
- assert hasattr (res , 'eigenvectors' ), "eigh() result namedtuple doesn't have the 'eigenvectors' field"
144
+ _test_namedtuple (res , ['eigenvalues' , 'eigenvectors' ], 'eigh' )
137
145
138
146
eigenvalues = res .eigenvalues
139
147
eigenvectors = res .eigenvectors
140
148
141
- assert_exactly_equal (res [0 ], eigenvalues )
142
- assert_exactly_equal (res [1 ], eigenvectors )
143
-
144
149
assert eigenvalues .dtype == x .dtype , "eigh().eigenvalues did not return the correct dtype"
145
150
assert eigenvalues .shape == x .shape [:- 1 ], "eigh().eigenvalues did not return the correct shape"
146
151
0 commit comments