Skip to content

Commit 9b893ca

Browse files
committed
Factor namedtuple checks into a helper
1 parent 3711f3a commit 9b893ca

File tree

1 file changed

+15
-10
lines changed

1 file changed

+15
-10
lines changed

array_api_tests/test_linalg.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,20 @@ def _test_stacks(f, x, kw, res=None, dims=2, true_val=None):
4747
if true_val:
4848
assert_exactly_equal(decomp_res_stack, true_val(x_stack))
4949

50+
def _test_namedtuple(res, fields, func_name):
51+
"""
52+
Test that res is a namedtuple with the correct fields.
53+
"""
54+
# isinstance(namedtuple) doesn't work, and it could be either
55+
# collections.namedtuple or typing.NamedTuple. So we just check that it is
56+
# a tuple subclass with the right fields in the right order.
57+
58+
assert isinstance(res, tuple), f"{func_name}() did not return a tuple"
59+
assert len(res) == len(fields), f"{func_name}() result tuple not the correct length (should have {len(fields)} elements)"
60+
for i, field in enumerate(fields):
61+
assert hasattr(res, field), f"{func_name}() result namedtuple doesn't have the '{field}' field"
62+
assert res[i] is getattr(res, field), f"{func_name}() result namedtuple '{field}' field is not in position {i}"
63+
5064
@given(
5165
x=positive_definite_matrices(),
5266
kw=kwargs(upper=booleans())
@@ -127,20 +141,11 @@ def true_diag(x_stack):
127141
def test_eigh(x):
128142
res = _array_module.linalg.eigh(x)
129143

130-
# TODO: Factor out namedtuple checks
131-
132-
# isinstance(namedtuple) doesn't work
133-
assert isinstance(res, tuple), "eigh() did not return a tuple"
134-
assert len(res) == 2, "eigh() result tuple not the correct length"
135-
assert hasattr(res, 'eigenvalues'), "eigh() result namedtuple doesn't have the 'eigenvalues' field"
136-
assert hasattr(res, 'eigenvectors'), "eigh() result namedtuple doesn't have the 'eigenvectors' field"
144+
_test_namedtuple(res, ['eigenvalues', 'eigenvectors'], 'eigh')
137145

138146
eigenvalues = res.eigenvalues
139147
eigenvectors = res.eigenvectors
140148

141-
assert_exactly_equal(res[0], eigenvalues)
142-
assert_exactly_equal(res[1], eigenvectors)
143-
144149
assert eigenvalues.dtype == x.dtype, "eigh().eigenvalues did not return the correct dtype"
145150
assert eigenvalues.shape == x.shape[:-1], "eigh().eigenvalues did not return the correct shape"
146151

0 commit comments

Comments
 (0)